From 5adb7bf1a4969be003fe1b9cd57d70e0e062609f Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Fri, 13 Aug 2021 23:15:34 +0530 Subject: [PATCH] Combined variants to reduce redundancy in dtrsm small 1. Left Lower non-trans,Left Upper trans 2. Left Upper non-trans,Left Lower trans 3. Right Lower non-trans.Right Upper trans 4. Right Upper non-trans,Right Lower trans Change-Id: I0b0155d7c3a55ec74d53c8f1f49f1bceb63b15f5 --- kernels/zen/3/bli_trsm_small.c | 25028 ++++++++----------------------- 1 file changed, 6314 insertions(+), 18714 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index b127fa4e7..ea9de2a88 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -37,9 +37,6 @@ #include "immintrin.h" #define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL -#define D_MR 8 -#define D_NR 6 - /* declaration of trsm small kernels function pointer @@ -55,7 +52,10 @@ typedef err_t (*trsmsmall_ker_ft) //AX = B; A is lower triangular; No transpose; //double precision; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_AlXB +//A.'X = B; A is upper triangular; +//A has to be transposed; double precision + +BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ( obj_t* AlphaObj, obj_t* a, @@ -68,7 +68,37 @@ BLIS_INLINE err_t bli_dtrsm_small_AlXB * A is upper-triangular, non-transpose, non-unit diagonal * dimensions A: mxm X: mxn B: mxn */ -BLIS_INLINE err_t bli_dtrsm_small_AuXB +//AX = B; A is lower triangular; transpose; double precision + +BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +//XA = B; A is upper-triangular; A is transposed; +//double precision; non-unit diagonal +// XA = B; A is lower-traingular; No transpose; +//double precision; non-unit diagonal + +BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +); + +// XA = B; A is upper triangular; No transpose; +//double presicion; non-unit diagonal +//XA = B; A is lower-triangular; A is transposed; +// double precision; non-unit-diagonal + +BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ( obj_t* AlphaObj, obj_t* a, @@ -90,28 +120,7 @@ BLIS_INLINE err_t dtrsm_AltXB_ref bool is_unitdiag ); -//A.'X = B; A is upper triangular; -//A has to be transposed; double precision -BLIS_INLINE err_t bli_dtrsm_small_AutXB -( - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -//AX = B; A is lower triangular; transpose; double precision -BLIS_INLINE err_t bli_dtrsm_small_AltXB -( - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -/* +/* * The preinversion of diagonal elements are enabled/disabled * based on configuration. 
*/ @@ -146,16 +155,16 @@ BLIS_INLINE err_t dtrsm_AutXB_ref dim_t i, j, k; for (k = 0; k < M; k++) { - double lkk_inv = 1.0; - if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = 0; j < N; j++) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv); - for (i = k+1; i < M; i++) - { - B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; - } - } + double lkk_inv = 1.0; + if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = 0; j < N; j++) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv); + for (i = k+1; i < M; i++) + { + B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; + } + } }// k -loop return BLIS_SUCCESS; }// end of function @@ -178,16 +187,16 @@ BLIS_INLINE err_t dtrsm_AuXB_ref dim_t i, j, k; for (k = M-1; k >= 0; k--) { - double lkk_inv = 1.0; - if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = N -1; j >= 0; j--) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); - for (i = k-1; i >=0; i--) - { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; - } - } + double lkk_inv = 1.0; + if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); + for (j = N -1; j >= 0; j--) + { + B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); + for (i = k-1; i >=0; i--) + { + B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; + } + } }// k -loop return BLIS_SUCCESS; }// end of function @@ -256,50 +265,6 @@ BLIS_INLINE err_t dtrsm_AltXB_ref return BLIS_SUCCESS; }// end of function -// XA = B; A is lower-traingular; No transpose; -//double precision; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAlB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -//XA = B; A is lower-triangular; A is transposed; -// double precision; non-unit-diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAltB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -// XA = B; A is upper triangular; No transpose; -//double presicion; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAuB -( - obj_t* alpha, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - -//XA = B; A is upper-triangular; A is transposed; -//double precision; non-unit diagonal -BLIS_INLINE err_t bli_dtrsm_small_XAutB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -); - /* TRSM scalar code for the case XA = alpha * B * A is upper-triangular, non-unit/unit diagonal no transpose * Dimensions: X:mxn A:nxn B:mxn @@ -341,7 +306,6 @@ BLIS_INLINE err_t dtrsm_XAlB_ref ( double *A, double *B, - double alpha, dim_t M, dim_t N, dim_t lda, @@ -350,13 +314,6 @@ BLIS_INLINE err_t dtrsm_XAlB_ref ) { dim_t i, j, k; - for(j = 0; j < N; j++) - { - for(i = 0; i < M; i++) - { - B[i+j*ldb] *= alpha; - } - } for(k = N;k--;) { @@ -383,7 +340,6 @@ BLIS_INLINE err_t dtrsm_XAutB_ref ( double *A, double *B, - double alpha, dim_t M, dim_t N, dim_t lda, @@ -392,13 +348,6 @@ BLIS_INLINE err_t dtrsm_XAutB_ref ) { dim_t i, j, k; - for(j = 0; j < N; j++) - { - for(i = 0; i < M; i++) - { - B[i+j*ldb] *=alpha; - } - } for(k = N; k--;) { @@ -474,7 +423,7 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm15 = _mm256_setzero_pd(); /*GEMM block used in trsm small right cases*/ -#define BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) \ for(k = 0; k < k_iter; k++) \ {\ /*load 8x1 block of B10*/ \ @@ -512,8 +461,199 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b; \ } -/*GEMM block used in trsm small left cases*/ -#define 
BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) \ +#define BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); /*A01[0][4]*/\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); /*A01[0][5]*/\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const 
*)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter)\ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4));/*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/\ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 8x1 block of B10*/\ + ymm0 = _mm256_loadu_pd((double const *)b10);/*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); 
/*A01[0][0]*/\ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + +/*GEMM block used in dtrsm small left cases*/ +#define BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) \ double *b01_prefetch = b01 + 8; \ for(k = 0; k< k_iter; k++) \ { \ @@ -554,6 +694,747 @@ BLIS_INLINE err_t dtrsm_XAltB_ref a10 += p_lda; \ } +#define BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12);\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + 
{\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4));\ + ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5));\ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));\ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));\ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));\ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) \ + for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + ymm0 = _mm256_loadu_pd((double const *)(a10));\ +\ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));\ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8);\ +\ + b01 += 1; /*move to next row of B*/\ + a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\ + } + +/* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform inregister transose of b11 + to peform DTRSM operation for left cases. 
+*/ +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2));\ + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3));\ + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);\ +\ + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); \ + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); \ + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); \ + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31);\ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); \ + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); \ + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); \ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + 4));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4));\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4));\ + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4));\ + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12);\ + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13);\ + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14);\ + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15);\ +\ + ymm13 = _mm256_unpacklo_pd(ymm0, ymm1);\ + ymm15 = _mm256_unpacklo_pd(ymm2, ymm3);\ + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20);\ + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31);\ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);\ + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);\ +\ + ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);\ + ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5));\ + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4);\ + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5);\ + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4));\ + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4));\ + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6);\ + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7);\ +\ + ymm16 = _mm256_broadcast_sd((double const *)(&ones));\ + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1);\ + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20);\ + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);\ +\ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);\ + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20);\ + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31);\ + ymm18 = _mm256_unpacklo_pd(ymm2, ymm3);\ + ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20);\ + ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31);\ +\ + /*unpackhigh*/\ + ymm20 = _mm256_unpackhi_pd(ymm2, ymm3);\ +\ + /*rearrange high elements*/\ + ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20);\ + ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); + +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b)\ + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9);\ + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31);\ +\ + /*unpack high*/\ + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9);\ + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);\ + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);\ + _mm256_storeu_pd((double 
*)(b11 + cs_b * 2), ymm2);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3);\ +\ + /*unpacklow*/\ + ymm1 = _mm256_unpacklo_pd(ymm12, ymm13);\ + ymm3 = _mm256_unpacklo_pd(ymm14, ymm15);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31);\ +\ + /*unpack high*/\ + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13);\ + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20);\ + ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31);\ +\ + _mm256_storeu_pd((double *)(b11 + 4), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3);\ +\ + /*unpacklow*/\ + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5);\ + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ +\ + /*unpack high*/\ + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5);\ + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);\ +\ + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1);\ +\ + /*unpacklow*/\ + ymm1 = _mm256_unpacklo_pd(ymm17, ymm18);\ + ymm3 = _mm256_unpacklo_pd(ymm19, ymm20);\ +\ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);\ +\ + /*unpack high*/\ + ymm17 = _mm256_unpackhi_pd(ymm17, ymm18);\ + ymm18 = _mm256_unpackhi_pd(ymm19, ymm20);\ +\ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20);\ +\ + _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0);\ + _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); + +#define BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2));\ + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ + xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5);\ + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + 
_mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5);\ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);\ +\ + xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); + + +#define BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); /*store(B11[0-3][1])*/\ + xmm5 = _mm256_extractf128_pd(ymm2, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0));\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C);\ +\ + _mm256_storeu_pd((double *)(b11), ymm0); /*store(B11[0-3][0])*/\ + xmm5 = _mm256_extractf128_pd(ymm1, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C);\ +\ + xmm5 = _mm256_extractf128_pd(ymm0, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1));\ + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2));\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + +#define 
BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1));\ +\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + +#define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b)\ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0));\ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);\ +\ + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + +/* pre & post TRSM for Right remainder cases*/ +#define BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ + xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5);\ + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);\ + xmm5 = _mm256_extractf128_pd(ymm7, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); + +#define BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = 
_mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b)\ + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01);\ +\ + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0));\ + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + +#define BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5);\ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + +#define BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03);\ + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03);\ +\ + _mm256_storeu_pd((double *)b11, ymm3);\ + xmm5 = _mm256_extractf128_pd(ymm5, 0);\ + _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b)\ + ymm0 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01);\ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01);\ +\ + _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0));\ + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + +#define BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double 
const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2));\ + ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b)\ + xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + _mm_storeu_pd((double *)(b11), xmm5);\ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); + +#define BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const*)(b11));\ + ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b)\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03);\ +\ + xmm5 = _mm256_extractf128_pd(ymm3, 0);\ + _mm_storeu_pd((double *)(b11), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm6 = _mm256_broadcast_sd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b)\ + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01);\ +\ + _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + +/* multiply with Alpha pre TRSM for 6*8 kernel*/ +#define BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4));\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4));\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4));\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4));\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ + ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4));\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13);\ + ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); + +#define BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal));\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4));\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4));\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ + ymm8 = 
_mm256_fmsub_pd(ymm1, ymm15, ymm8);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4));\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ + ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); + +#define BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_loadu_pd((double const *)b11);\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + /* Pack a block of 8xk or 6xk from input buffer into packed buffer directly or after transpose based on input params @@ -566,7 +1447,8 @@ BLIS_INLINE void bli_dtrsm_small_pack double *inbuf, dim_t cs_a, double *pbuff, - dim_t p_lda + dim_t p_lda, + dim_t mr ) { //scratch registers @@ -579,9 +1461,9 @@ BLIS_INLINE void bli_dtrsm_small_pack if(side=='L'||side=='l') { - /*Left case is 8xk*/ - if(trans) - { + /*Left case is 8xk*/ + if(trans) + { /* ------------- ------------- | | | | | @@ -591,117 +1473,117 @@ BLIS_INLINE void bli_dtrsm_small_pack | | | | | ------------- ------------- */ - for(dim_t x = 0; x < size; x += D_MR) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); - ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); - ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3)); - ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3)); + for(dim_t x = 0; x < size; x += mr) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(pbuff), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda*2), 
ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(pbuff), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda*3), ymm9); - ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); - _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8); - _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9); + _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4)); - ymm10 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); - ymm11 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6)); - ymm12 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7)); - ymm13 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7 + 4)); + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6)); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7)); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7 + 4)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(pbuff + 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda), ymm7); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(pbuff + 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda), ymm7); + 
_mm256_storeu_pd((double *)(pbuff + 4 + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda*3), ymm9); - ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); - ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8); - _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9); - inbuf += D_MR; - pbuff += D_MR*D_MR; - } - }else - { - //Expected multiples of 4 - p_lda = 8; - for(dim_t x = 0; x < size; x++) - { - ymm0 = _mm256_loadu_pd((double const *)(inbuf)); - _mm256_storeu_pd((double *)(pbuff), ymm0); - ymm1 = _mm256_loadu_pd((double const *)(inbuf + 4)); - _mm256_storeu_pd((double *)(pbuff + 4), ymm1); - inbuf+=cs_a; - pbuff+=p_lda; - } - } + inbuf += mr; + pbuff += mr*mr; + } + }else + { + //Expected multiples of 4 + p_lda = 8; + for(dim_t x = 0; x < size; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + _mm256_storeu_pd((double *)(pbuff), ymm0); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + 4)); + _mm256_storeu_pd((double *)(pbuff + 4), ymm1); + inbuf+=cs_a; + pbuff+=p_lda; + } + } }else if(side=='R'||side=='r') { - if(trans) - { - /* - ------------------ ---------- - | | | | | | - | 4x4 | 4x4 | | 4x4 |4x2 | - ------------- ==> ------------- - | | | | | | - | 2x4 | 2x4 | | 2x4 |2x2 | - ------------------- ------------- - */ - for(dim_t x=0; x ------------- + | | | | | | + | 2x4 | 2x4 | | 2x4 |2x2 | + ------------------- ------------- + */ + for(dim_t x=0; x>2); i++) { ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 0 )); @@ -830,7 +1712,7 @@ BLIS_INLINE void dtrsm_small_pack_diag_element __m256d ymm0, ymm1, ymm2, ymm3; __m256d ymm4, ymm5; double ones = 1.0; - bool is_eight = (size==D_MR) ? 1 : 0; + bool is_eight = (size==8) ? 
1 : 0; ymm4 = ymm5 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { @@ -879,20 +1761,20 @@ BLIS_INLINE void dtrsm_small_pack_diag_element _mm_storeu_pd((double *)(d11_pack + 4), _mm256_extractf128_pd(ymm5,0)); } } - + /* * Kernels Table */ trsmsmall_ker_ft ker_fps[8] = { - bli_dtrsm_small_AlXB, - bli_dtrsm_small_AltXB, - bli_dtrsm_small_AuXB, - bli_dtrsm_small_AutXB, - bli_dtrsm_small_XAlB, - bli_dtrsm_small_XAltB, - bli_dtrsm_small_XAuB, - bli_dtrsm_small_XAutB + bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AltXB_AuXB, + bli_dtrsm_small_AutXB_AlXB, + bli_dtrsm_small_XAutB_XAlB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAltB_XAuB, + bli_dtrsm_small_XAutB_XAlB }; /* @@ -930,7 +1812,7 @@ err_t bli_trsm_small /* ToDo: Temporary threshold condition for trsm single thread. * It will be updated with arch based threshold function which reads * tunned thresholds for all 64 (datatype,side,uplo,transa,unit,) trsm - combinations. We arrived to this condition based on performance + combinations. We arrived to this condition based on performance comparsion with only available native path */ if(m > 1000 || n > 1000) { @@ -985,10415 +1867,23 @@ err_t bli_trsm_small return err; }; -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, no-transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn - - b01---> - * ***************** - ** * * * * * - * * * * * * * - * * *b01* * * * - * * * * * * * -a10 ****** b11 ***************** - | * * * | * * * * * - | * * * | * * * * * - | *a10*a11* | *b11* * * * - v * * * v * * * * * - *********** ***************** - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - **************** ***************** - a11---> -*/ - -BLIS_INLINE err_t bli_dtrsm_small_AlXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. 
- bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR - a. Load and pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n rows of B in steps of D_NR - */ - for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension - { - a10 = L + (i); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - - dim_t p_lda = D_MR; // packed leading dimension - - /* - Pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', i, 0, a10, cs_a, D_A_pack, p_lda); - - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); - - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR columns of B01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i ; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. 
- */ - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); - //b11 transpose end - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. 
ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 - where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - 
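/*
 * Assumed behaviour of DTRSM_SMALL_DIV_OR_SCALE used above (the real macro
 * is defined earlier in this file): with preinversion enabled, d11_pack
 * holds 1/a(k,k) and the solve multiplies; otherwise d11_pack holds a(k,k)
 * and the solve divides. Sketch only, with a _SKETCH name to avoid any
 * clash with the actual macro.
 */
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
  #define DTRSM_SMALL_DIV_OR_SCALE_SKETCH( b, d ) _mm256_mul_pd( (b), (d) )
#else
  #define DTRSM_SMALL_DIV_OR_SCALE_SKETCH( b, d ) _mm256_div_pd( (b), (d) )
#endif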
ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - a11 += cs_a; - - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - a11 += cs_a; - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = 
_mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); - } - - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 
1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - a11 += cs_a; - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - a11 += cs_a; - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - - //perform mul operation - ymm14 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); - - n_rem -=4; - j +=4; - - } - - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - 
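/*
 * The n_rem branches above seed the unused B11 columns with 1.0 so that the
 * shared transpose and solve code below can run unconditionally; only the
 * valid columns are stored back at the end. A hypothetical one-liner showing
 * the padding idea (not part of the kernel):
 */
static inline __m256d dtrsm_small_pad_or_load_col_sketch
     ( const double* col, bool valid, double one )
{
    /* valid -> real column load; otherwise neutral 1.0 padding */
    return valid ? _mm256_loadu_pd( col ) : _mm256_broadcast_sd( &one );
}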
///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); 
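/*
 * The B11 transposes in this kernel use the standard unpacklo/unpackhi plus
 * permute2f128 idiom. A self-contained 4x4 double transpose showing the same
 * pattern (illustrative helper, not part of the kernel; assumes immintrin.h
 * is already included, as elsewhere in this file). r holds a row-major 4x4
 * block; its transpose is written to c.
 */
static void dtrsm_small_transpose_4x4d_sketch( const double* r, double* c )
{
    __m256d r0 = _mm256_loadu_pd( r + 0  );                              /* row 0 */
    __m256d r1 = _mm256_loadu_pd( r + 4  );                              /* row 1 */
    __m256d r2 = _mm256_loadu_pd( r + 8  );                              /* row 2 */
    __m256d r3 = _mm256_loadu_pd( r + 12 );                              /* row 3 */
    __m256d t0 = _mm256_unpacklo_pd( r0, r1 );                           /* 00 10 02 12 */
    __m256d t1 = _mm256_unpackhi_pd( r0, r1 );                           /* 01 11 03 13 */
    __m256d t2 = _mm256_unpacklo_pd( r2, r3 );                           /* 20 30 22 32 */
    __m256d t3 = _mm256_unpackhi_pd( r2, r3 );                           /* 21 31 23 33 */
    _mm256_storeu_pd( c + 0,  _mm256_permute2f128_pd( t0, t2, 0x20 ) );  /* col 0 */
    _mm256_storeu_pd( c + 4,  _mm256_permute2f128_pd( t1, t3, 0x20 ) );  /* col 1 */
    _mm256_storeu_pd( c + 8,  _mm256_permute2f128_pd( t0, t2, 0x31 ) );  /* col 2 */
    _mm256_storeu_pd( c + 12, _mm256_permute2f128_pd( t1, t3, 0x31 ) );  /* col 3 */
}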
- ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - a11 += cs_a; - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - a11 += cs_a; - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 +7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - 
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - } - } - } - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (8x6) - above holds for reminder cases too. - */ - - dim_t m_rem = m-i; - //implementation for reamainder rows(when 'M' is not a multiple of D_MR) - if(m_rem>=4) - { - a10 = L + (i); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - double *ptr_a10_dup = D_A_pack; - double *ptr_a11_dup = a11; - - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < i;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + cs_a * x)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); - - //Pick one element each column and create a 4 element vector and store - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM operation to be done - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, 
ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //b11 transpose end - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, 
ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - 
_mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - 
ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - n_rem -= 4; - j += 4; - } - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] 
- ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - a11 += cs_a; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] 
B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } - } - m_rem -=4; - i +=4; - } - - if(m_rem) - { - a10 = L + (i); //pointer to block of A to be used for GEMM - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_rem) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, 
ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 
2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_rem) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - 
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - - dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - 
_mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - - dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); - - dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - - } - m_rem -=2; - i+=2; - } - else if(1 == m_rem) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - 
dtrsm_AlXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j+=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = 
_mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - - dtrsm_AlXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - m_rem -=1; - i+=1; - } - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - - return BLIS_SUCCESS; -} - -/* TRSM for the Left Upper case AX = alpha * B, Double precision - * A is Left side, upper-triangular, transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn - a10 ----> b11---> - *********** ***************** - * * * * *b01*b11* * * - **a10 * * a11 b11 * * * * * - ********* | | ***************** - *a11* * | | * * * * * - * * * | | * * * * * - ****** v v ***************** - * * * * * * * - * * * * * * * - * * ***************** - * - a11---> -*/ -BLIS_INLINE err_t bli_dtrsm_small_AutXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. 
Repeat b,c for n rows of B in steps of D_NR - */ - for(i = 0;(i+D_MR-1) < m; i += D_MR) //loop along 'M' dimension - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); - dim_t p_lda = D_MR; // packed leading dimension - - /* - Load, tranpose and pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda); - - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); - - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - dim_t temp = n - D_NR + 1; - for(j = 0; j < temp; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. 
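            In formula form, each 4-wide fragment of a b11 column is updated as
                b11 := alpha*b11 - (packed a10 * b01)
            in a single _mm256_fmsub_pd against the GEMM accumulators
            (ymm4-ymm15), and only then run through the unpack/permute
            transpose that follows.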
- */ - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm4 = 
_mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm8, ymm4 2. ymm9, ymm5 3. ymm10, ymm6, 4. ymm11, ymm7 - 5. ymm12, ymm17 6. ymm13,ymm18, 7. ymm14,ymm19 8. ymm15, ymm20 - where ymm8-ymm15 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20); - - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - a11 += 1; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm5, ymm17); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - a11 += 1; - - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - a11 += 1; - - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - a11 += 1; - - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); 
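/* Sketch only (not part of the patch): a scalar equivalent of the unrolled
   AVX2 substitution in this 8x6 TRSM block, for reference while reading the
   ymm sequence.  x[k][0..5] plays the role of the k-th transposed row of B11
   held in the register pairs listed earlier (ymm8/ymm4 ... ymm15/ymm20), and
   d11_pack[k] is 1.0/A[k][k] when pre-inversion is enabled, or A[k][k]
   itself otherwise.  Names below are illustrative, not from the file. */
static void dtrsm_small_8x6_substitution_sketch
     (
       const double* a11,       /* 8x8 triangular block, column stride cs_a            */
       double        x[8][6],   /* transposed B11 panel, already alpha*B11 - A10*B01   */
       const double* d11_pack,  /* packed diagonal (reciprocals when pre-inverted)     */
       dim_t         cs_a
     )
{
    for (dim_t k = 0; k < 8; k++)
    {
        for (dim_t c = 0; c < 6; c++)          /* the DTRSM_SMALL_DIV_OR_SCALE step    */
            x[k][c] *= d11_pack[k];            /* (a divide when pre-inversion is off) */

        for (dim_t p = k + 1; p < 8; p++)      /* the _mm256_fnmadd_pd eliminations    */
            for (dim_t c = 0; c < 6; c++)
                x[p][c] -= a11[k + p*cs_a] * x[k][c];
    }
}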
- ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - a11 += 1; - - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); - ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - a11 += 1; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + 4), ymm0); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); //store B11[7][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store 
B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); //store B11[5][0-3] - } - - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i ; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - 
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //extract a22 
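/* Sketch only (not part of the patch): the 4x4 in-register transpose idiom
   used throughout these blocks, shown as a standalone helper.  Four column
   vectors of doubles (as loaded from column-major b11) go in, four row
   vectors come out, via unpacklo/unpackhi + permute2f128 exactly as in the
   ymm sequences above.  Requires immintrin.h / AVX. */
static inline void dtrsm_small_transpose_4x4_sketch
     (
       __m256d c0, __m256d c1, __m256d c2, __m256d c3, /* columns B[0..3][j] */
       __m256d r[4]                                    /* rows    B[i][0..3] */
     )
{
    __m256d lo01 = _mm256_unpacklo_pd(c0, c1);        /* B00 B01 B20 B21 */
    __m256d lo23 = _mm256_unpacklo_pd(c2, c3);        /* B02 B03 B22 B23 */
    __m256d hi01 = _mm256_unpackhi_pd(c0, c1);        /* B10 B11 B30 B31 */
    __m256d hi23 = _mm256_unpackhi_pd(c2, c3);        /* B12 B13 B32 B33 */

    r[0] = _mm256_permute2f128_pd(lo01, lo23, 0x20);  /* B00 B01 B02 B03 */
    r[1] = _mm256_permute2f128_pd(hi01, hi23, 0x20);  /* B10 B11 B12 B13 */
    r[2] = _mm256_permute2f128_pd(lo01, lo23, 0x31);  /* B20 B21 B22 B23 */
    r[3] = _mm256_permute2f128_pd(hi01, hi23, 0x31);  /* B30 B31 B32 B33 */
}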
- ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //(ROw4): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - - //(ROw5): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); - - a11 += 1; - - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw6): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - //(ROw7): FMA operations - ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low 
elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] - - n_rem -=4; - j +=4; - - } - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); 
//B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - b01 += 1; //move to next row of B - a10 += D_MR; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - 
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm10 = 
_mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //extract a44 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //(ROw4): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); - ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - - //extract a55 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - - //(ROw5): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); - - a11 += 1; - - //extract a66 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - - //(ROw6): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); - ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a77 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); - - a11 += 1; - //(ROw7): FMA operations - ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] 
B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - } - } - - } - //======================M remainder cases================================ - dim_t m_rem = m-i; - if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); - double *ptr_a10_dup = D_A_pack; - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < i;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const 
*)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = 
_mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - //ymm16; - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - //ymm16; - - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - //b11 transpose end - - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw1): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); - - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - a11 += 1; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - a11 += 1; - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, 
ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_rem = n-j; - if(n_rem >= 4) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - 
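/* Editor's aside - the "transpose of B11" below follows the standard AVX 4x4
 * double-precision transpose idiom used throughout this file: unpacklo/unpackhi
 * interleave pairs of rows within each 128-bit lane, then _mm256_permute2f128_pd
 * with controls 0x20/0x31 gathers the low/high lanes into full columns.  A
 * self-contained sketch of that idiom (hypothetical helper, not part of this
 * patch; immintrin.h is already included by this file):
 */
static inline void dtrsm_small_tran4x4_sketch
     (
       __m256d r0, __m256d r1, __m256d r2, __m256d r3,     /* rows in     */
       __m256d *c0, __m256d *c1, __m256d *c2, __m256d *c3  /* columns out */
     )
{
    __m256d lo01 = _mm256_unpacklo_pd(r0, r1);  /* r0[0] r1[0] r0[2] r1[2] */
    __m256d lo23 = _mm256_unpacklo_pd(r2, r3);  /* r2[0] r3[0] r2[2] r3[2] */
    __m256d hi01 = _mm256_unpackhi_pd(r0, r1);  /* r0[1] r1[1] r0[3] r1[3] */
    __m256d hi23 = _mm256_unpackhi_pd(r2, r3);  /* r2[1] r3[1] r2[3] r3[3] */

    *c0 = _mm256_permute2f128_pd(lo01, lo23, 0x20); /* r0[0] r1[0] r2[0] r3[0] */
    *c1 = _mm256_permute2f128_pd(hi01, hi23, 0x20); /* r0[1] r1[1] r2[1] r3[1] */
    *c2 = _mm256_permute2f128_pd(lo01, lo23, 0x31); /* r0[2] r1[2] r2[2] r3[2] */
    *c3 = _mm256_permute2f128_pd(hi01, hi23, 0x31); /* r0[3] r1[3] r2[3] r3[3] */
}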
///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - - a11 += 1; - - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - - a11 += 1; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - n_rem -= 4; - j += 4; - } - if(n_rem) - { - a10 = D_A_pack; - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer 
to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - - ///transpose of B11// - ///unpacklow/// - ymm9 = 
_mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - ////extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - - a11 += 1; - - //(ROw1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - - a11 += 1; - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); - - a11 += 1; - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw5): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_rem) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); 
//store B11[0][0-3] - } - - } - m_rem -=4; - i +=4; - } - - if(m_rem) - { - a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_rem) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); - n_rem -= 4; - j +=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), 
_MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = 
_mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_rem) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j +=4; - } - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B 
+ (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next 
block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); - - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - - } - m_rem -=2; - i+=2; - } - else if(1 == m_rem) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x=0;x= 4)) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); - n_rem -= 4; - j+=4; - } - - if(n_rem) - { - a10 = D_A_pack; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM - b11 = B + i + (j* cs_b); //pointer to block of B to be used 
for TRSM - - k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_rem) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_rem) - { - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += 4; //pointer math to calculate next block of A for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm8 = _mm256_fmsub_pd(ymm0, 
ymm16, ymm8); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - - dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); - } - } - m_rem -=1; - i+=1; - } - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn -*/ -BLIS_INLINE err_t bli_dtrsm_small_AltXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( ( D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if(required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/D_MR in steps of D_MR - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-D_MR) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n rows of B in steps of D_NR - */ - for(i = (m - D_MR); (i + 1) > 0; i -= D_MR) - { - a10 = L + (i*cs_a) + i + D_MR; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - //ptr_a10_dup = D_A_pack; - - dim_t p_lda = D_MR; // packed leading dimension - /* - Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. 
This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', (m-i-D_MR), 1, a10, cs_a, D_A_pack,p_lda); - - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); - - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. - */ - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) - { - a10 = D_A_pack; - b01 = B + (j * cs_b) + i + D_MR; //pointer to block of B to be used for GEMM - b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - D_MR); - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. - */ - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - ymm0 = 
_mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); - - ////unpackhigh//// - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); - - //rearrange high elements - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 - 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 - where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 
1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); - - } - - dim_t n_remainder = j + D_NR; - if(n_remainder >= 4) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; - b11 = B + ((n_remainder - 4)* cs_b) + i; - - k_iter = (m - i - D_MR); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = 
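The in-register transpose used before and after the solve is the standard AVX 4x4 double transpose: unpacklo/unpackhi interleave pairs of rows within each 128-bit lane, and permute2f128 then exchanges the lanes. A stand-alone example of the idiom on a single 4x4 tile (compile with -mavx; not part of the patch):

    #include <stdio.h>
    #include <immintrin.h>

    int main(void)
    {
        double m[4][4] = { { 0, 1, 2, 3},
                           {10,11,12,13},
                           {20,21,22,23},
                           {30,31,32,33} }, out[4][4];

        __m256d r0 = _mm256_loadu_pd(m[0]);
        __m256d r1 = _mm256_loadu_pd(m[1]);
        __m256d r2 = _mm256_loadu_pd(m[2]);
        __m256d r3 = _mm256_loadu_pd(m[3]);

        __m256d t0 = _mm256_unpacklo_pd(r0, r1);  /* r0[0] r1[0] r0[2] r1[2] */
        __m256d t1 = _mm256_unpacklo_pd(r2, r3);  /* r2[0] r3[0] r2[2] r3[2] */
        __m256d t2 = _mm256_unpackhi_pd(r0, r1);  /* r0[1] r1[1] r0[3] r1[3] */
        __m256d t3 = _mm256_unpackhi_pd(r2, r3);  /* r2[1] r3[1] r2[3] r3[3] */

        __m256d c0 = _mm256_permute2f128_pd(t0, t1, 0x20);  /* column 0 */
        __m256d c1 = _mm256_permute2f128_pd(t2, t3, 0x20);  /* column 1 */
        __m256d c2 = _mm256_permute2f128_pd(t0, t1, 0x31);  /* column 2 */
        __m256d c3 = _mm256_permute2f128_pd(t2, t3, 0x31);  /* column 3 */

        _mm256_storeu_pd(out[0], c0);
        _mm256_storeu_pd(out[1], c1);
        _mm256_storeu_pd(out[2], c2);
        _mm256_storeu_pd(out[3], c3);

        for (int i = 0; i < 4; i++)
            printf("%4.0f %4.0f %4.0f %4.0f\n",
                   out[i][0], out[i][1], out[i][2], out[i][3]);
        return 0;
    }

The 8x6 path uses the same building blocks twice per tile, padding missing rows or columns with vectors of ones so the pattern stays uniform.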
_mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = 
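Each k-iteration of the GEMM loops above is a rank-1 update: one packed column of a10 (eight doubles split across two ymm registers) is multiplied by broadcast elements of one b01 row and accumulated. A self-contained sketch of that pattern (the helper name and acc layout are illustrative, not kernel symbols):

    #include <immintrin.h>

    /* C(8x4) += a10(8 x k_iter) * b01(k_iter x 4); a10 is packed so that
       consecutive columns are p_lda doubles apart, b01 is read row by row. */
    static void gemm_8x4_accum_sketch(const double *a10, const double *b01,
                                      __m256d acc[2][4],
                                      int k_iter, int p_lda, int cs_b)
    {
        for (int k = 0; k < k_iter; k++)
        {
            __m256d a_lo = _mm256_loadu_pd(a10);      /* rows 0..3 */
            __m256d a_hi = _mm256_loadu_pd(a10 + 4);  /* rows 4..7 */
            for (int c = 0; c < 4; c++)
            {
                __m256d b = _mm256_broadcast_sd(b01 + c*cs_b);
                acc[0][c] = _mm256_fmadd_pd(b, a_lo, acc[0][c]);
                acc[1][c] = _mm256_fmadd_pd(b, a_hi, acc[1][c]);
            }
            b01 += 1;       /* next row of B01           */
            a10 += p_lda;   /* next packed column of a10 */
        }
    }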
_mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); - - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); - - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm6 = 
_mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 
+ 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] - n_remainder -=4; - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + D_MR; - b11 = B + i; - - k_iter = (m - i - D_MR) ; - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] 
B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); 
//B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7)); - - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6)); - - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double 
const *)(d11_pack + 2)); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] - } - else if(1 == n_remainder) - { - 
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] - } - } - }// End of multiples of D_MR blocks in m-dimension - - // Repetative A blocks will be 4*4 - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - i = m_remainder - 4; - a10 = L + (i*cs_a) + i + 4; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-i+4;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double 
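d11_pack therefore holds either the diagonal of a11 or its reciprocals, depending on the preinversion configuration, and ones when the diagonal is unit; DTRSM_SMALL_DIV_OR_SCALE then divides or multiplies accordingly. A scalar sketch of the same packing (illustrative helper, not part of the patch, which does this in ymm registers):

    static void pack_diag_sketch(const double *a11, double *d11_pack,
                                 int nb, int cs_a, int is_unitdiag)
    {
        for (int k = 0; k < nb; k++)
        {
            double d = is_unitdiag ? 1.0 : a11[k + k*cs_a];
    #ifdef BLIS_ENABLE_TRSM_PREINVERSION
            d11_pack[k] = 1.0 / d;   /* solve step multiplies by this      */
    #else
            d11_pack[k] = d;         /* solve step divides by this instead */
    #endif
        }
    }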
const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = 
_mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] 
B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - n_remainder = n_remainder - 4; - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + 4; - b11 = B + i; - - k_iter = (m - i - 4); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = 
_mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = 
_mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } - } - m_remainder -= 4; - } - - if(m_remainder) - { - a10 = L + m_remainder; - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_remainder) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = 
_mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double 
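For m_remainder == 3 the column vectors still span four rows, so the element in lane 3 belongs to the block below, which is already solved; _mm256_blend_pd with imm8 = 0x08 puts that element back before the store so it is rewritten unchanged, and the 3x3 triangular solve itself is left to dtrsm_AltXB_ref. A sketch of the masked update (hypothetical helper name, one column at a time):

    #include <immintrin.h>

    /* Apply alpha*B - acc to rows 0..2 of one column of B while leaving
       row 3 (already solved) untouched. */
    static void update_3_rows_sketch(double *bcol_ptr, __m256d acc, double alpha)
    {
        __m256d alphav = _mm256_set1_pd(alpha);
        __m256d bcol   = _mm256_loadu_pd(bcol_ptr);           /* rows 0..3   */
        __m256d upd    = _mm256_fmsub_pd(bcol, alphav, acc);  /* alpha*B-acc */
        __m256d out    = _mm256_blend_pd(upd, bcol, 0x08);    /* keep lane 3 */
        _mm256_storeu_pd(bcol_ptr, out);
    }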
const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 
1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_remainder) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - 
_mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of 
GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); 
//store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - - } - else if(1 == m_remainder) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x+=4) - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - - a10 += 4; - ptr_a10_dup += 4*4; - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM 
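Note on the remainder handling above: the m_remainder paths load a full 4-element vector from B even though only 3, 2, or 1 rows are valid, compute alpha*B minus the GEMM accumulation into a scratch register, and then use _mm256_blend_pd with masks 0x08, 0x0C, and 0x0E to restore the untouched lanes from the original data before storing, so rows beyond m_remainder are written back unchanged. A minimal, standalone sketch of that masking idea (array and function names here are illustrative, not from the kernel):

#include <immintrin.h>
#include <stdio.h>

/* Keep the first `valid` lanes of `result` and restore the remaining lanes
   from `original`, mirroring the kernel's blend masks 0x08 (3 valid rows),
   0x0C (2 rows) and 0x0E (1 row). */
static __m256d keep_valid_rows(__m256d result, __m256d original, int valid)
{
    switch (valid)
    {
        case 3:  return _mm256_blend_pd(result, original, 0x08);
        case 2:  return _mm256_blend_pd(result, original, 0x0C);
        case 1:  return _mm256_blend_pd(result, original, 0x0E);
        default: return result;
    }
}

int main(void)
{
    double b[4] = {10.0, 20.0, 30.0, 40.0};  /* original column of B      */
    double r[4] = { 1.0,  2.0,  3.0,  4.0};  /* freshly computed values   */
    double out[4];

    __m256d vb = _mm256_loadu_pd(b);
    __m256d vr = _mm256_loadu_pd(r);

    _mm256_storeu_pd(out, keep_valid_rows(vr, vb, 2));  /* two valid rows */
    printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  /* 1 2 30 40 */
    return 0;
}

Compile with AVX enabled (e.g. -mavx); the full-width store is then safe because the invalid lanes hold the values that were already in memory.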
- - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - 
_mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - - ymm8 = 
_mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - - dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - //register to hold alpha - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - } - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm,&local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - -/* - * TRSM for the case AX = alpha * B, Double precision - * A is upper-triangular, non-transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn -*/ -BLIS_INLINE err_t bli_dtrsm_small_AuXB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - //pointers that point to blocks for GEMM and TRSM - double *a10, *a11, *b01, *b11; - //double *ptr_a10_dup; - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - //scratch registers - __m256d ymm0, 
ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16, ymm17, ymm18, ymm19; - __m256d ymm20; - - __m128d xmm5; - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_MR * m * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - /* - Performs solving TRSM for 8 colmns at a time from 0 to m/8 in steps of D_MR - a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) - First there will be no GEMM and no packing of a10 because it is only TRSM - b. Using packed a10 block and b01 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B - d. Repeat b,c for n row of B in steps of D_NR - */ - for(i = (m - D_MR); (i + 1) > 0; i -= D_MR) - { - a10 = L + (i + D_MR)*cs_a + i; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - //ptr_a10_dup = D_A_pack; //ptr_a11_dup = a11; - dim_t p_lda = D_MR; // packed leading dimension - - /* - Pack current A block (a10) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a10 block size will be increasing by D_MR for every next itteration - untill it reaches 8x(m-8) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all n rows of B matrix - */ - bli_dtrsm_small_pack('L', (m-i-D_MR), 0, a10, cs_a, D_A_pack, p_lda); - - /* - Pack 8 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_MR); - - /* - a. Perform GEMM using a10, b01. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along n dimension for every D_NR rows of b01 where - packed A buffer is reused in computing all n rows of B. - d. Same approch is used in remaining fringe cases. 
- */ - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - b01 = B + (j*cs_b) + i + D_MR; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - D_MR); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a10 and b01 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_8x6(a10,b01,cs_b,p_lda,k_iter) - - /* - Load b11 of size 6x8 and multiply with alpha - Add the GEMM output and perform inregister transose of b11 - to peform TRSM operation. - */ - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm12); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm13); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm14); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm15); - - ymm13 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm15 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm15 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *4 + 4)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *5 + 4)); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); - - ymm16 = 
_mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); - ymm18 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm17 = _mm256_permute2f128_pd(ymm18,ymm16,0x20); - ymm19 = _mm256_permute2f128_pd(ymm18,ymm16,0x31); - ymm20 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm18 = _mm256_permute2f128_pd(ymm20,ymm16,0x20); - ymm20 = _mm256_permute2f128_pd(ymm20,ymm16,0x31); - - /* - Compute 8x6 TRSM block by using GEMM block output in register - a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 - 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 - where ymm15-ymm8 holds 8x4 data and reaming 8x2 will be hold by - other registers - b. Towards the end do in regiser transpose of TRSM output and store in b11 - */ - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double 
const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm3 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); - - ///unpack high/// - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm3 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm1); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm2); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm3); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm17, ymm18); - ymm3 = _mm256_unpacklo_pd(ymm19, ymm20); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); - - ///unpack high/// - ymm17 = _mm256_unpackhi_pd(ymm17, ymm18); - ymm18 = _mm256_unpackhi_pd(ymm19, ymm20); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm17, ymm18, 0x20); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4 + 4), ymm0); - _mm256_storeu_pd((double *)(b11 + cs_b * 5 + 4), ymm1); - - } - - dim_t n_remainder = j + D_NR; - if(n_remainder >= 4) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + ((n_remainder - 4)* cs_b) + i + D_MR; - b11 = B + ((n_remainder - 4)* cs_b) + i; - - k_iter = (m - i - D_MR); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); 
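The load/store sequences above pair _mm256_unpacklo_pd/_mm256_unpackhi_pd with _mm256_permute2f128_pd to transpose a 4x4 block of doubles held in four ymm registers, moving B11 between its column-major layout in memory and the row layout used during the solve. A standalone sketch of that transpose pattern (function and array names are illustrative):

#include <immintrin.h>
#include <stdio.h>

/* Transpose a 4x4 block of doubles: c0..c3 hold the four columns on input
   and the four rows on output, using the same unpack/permute pattern as the
   kernel's in-register transpose of B11. */
static void transpose_4x4_pd(__m256d *c0, __m256d *c1, __m256d *c2, __m256d *c3)
{
    __m256d lo01 = _mm256_unpacklo_pd(*c0, *c1);  /* c0[0] c1[0] c0[2] c1[2] */
    __m256d lo23 = _mm256_unpacklo_pd(*c2, *c3);  /* c2[0] c3[0] c2[2] c3[2] */
    __m256d hi01 = _mm256_unpackhi_pd(*c0, *c1);  /* c0[1] c1[1] c0[3] c1[3] */
    __m256d hi23 = _mm256_unpackhi_pd(*c2, *c3);  /* c2[1] c3[1] c2[3] c3[3] */

    *c0 = _mm256_permute2f128_pd(lo01, lo23, 0x20);  /* row 0 */
    *c1 = _mm256_permute2f128_pd(hi01, hi23, 0x20);  /* row 1 */
    *c2 = _mm256_permute2f128_pd(lo01, lo23, 0x31);  /* row 2 */
    *c3 = _mm256_permute2f128_pd(hi01, hi23, 0x31);  /* row 3 */
}

int main(void)
{
    double col[4][4] = { {0,1,2,3}, {4,5,6,7}, {8,9,10,11}, {12,13,14,15} };
    __m256d c0 = _mm256_loadu_pd(col[0]);
    __m256d c1 = _mm256_loadu_pd(col[1]);
    __m256d c2 = _mm256_loadu_pd(col[2]);
    __m256d c3 = _mm256_loadu_pd(col[3]);

    transpose_4x4_pd(&c0, &c1, &c2, &c3);

    double row[4];
    _mm256_storeu_pd(row, c1);
    printf("%g %g %g %g\n", row[0], row[1], row[2], row[3]);  /* 1 5 9 13 */
    return 0;
}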
- ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); - ymm16 
= _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); - - //(ROw7): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); - - //(ROw6): FMA operations - ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - - //(ROw5): FMA operations - ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); - ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - - //(ROw4): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); - ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - - //(ROw3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - ymm7 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - //(ROw2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, 
ymm9); - ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); - - //(ROw2): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); - n_remainder -=4; - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + D_MR; - b11 = B + i; - - k_iter = (m - i - D_MR); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const 
*)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); - - b01 += 1; //move to next row of B - a10 += p_lda; - - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); - ymm5 = _mm256_broadcast_sd((double const *)(&ones)); - ymm6 = _mm256_broadcast_sd((double const *)(&ones)); - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - } - ///implement TRSM/// - - 
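The solve that follows reads the diagonal of a11 from the small d11_pack array filled earlier by dtrsm_small_pack_diag_element (ones are stored for the unit-diagonal case) and applies it through DTRSM_SMALL_DIV_OR_SCALE, i.e. either a divide by the diagonal element or a multiply by its pre-inverted copy. A minimal scalar sketch of that packing/apply step, assuming a column-major a; the preinvert switch and function names below are illustrative, not the library's:

#include <stdio.h>

/* Illustrative sketch: pack the first n diagonal elements of a column-major
   matrix `a` (leading dimension lda) into d_pack. Unit-diagonal inputs get
   1.0; otherwise either the element or its reciprocal is stored, which is
   the divide-vs-scale choice behind DTRSM_SMALL_DIV_OR_SCALE. */
static void pack_diag_example(int is_unitdiag, int preinvert,
                              const double *a, long lda,
                              double *d_pack, long n)
{
    for (long k = 0; k < n; k++)
    {
        if (is_unitdiag)    d_pack[k] = 1.0;
        else if (preinvert) d_pack[k] = 1.0 / a[k + k * lda];
        else                d_pack[k] = a[k + k * lda];
    }
}

/* Apply one packed diagonal entry to a solved value. */
static double div_or_scale_example(double x, double d,
                                   int is_unitdiag, int preinvert)
{
    if (is_unitdiag) return x;
    return preinvert ? x * d : x / d;
}

int main(void)
{
    double a[9] = { 2, 0, 0,   0, 4, 0,   0, 0, 8 };  /* 3x3, column-major */
    double d_pack[3];

    pack_diag_example(0, 1, a, 3, d_pack, 3);          /* 0.5 0.25 0.125 */
    printf("%g\n", div_or_scale_example(16.0, d_pack[2], 0, 1));  /* 2 */
    return 0;
}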
///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); - - //perform mul operation - ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); - - //(ROw7): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 + 7*cs_a)); - ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 7*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 7*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 7*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 7*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 7*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); - - //perform mul operation - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(ROw6): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5 + 6*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 6*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 6*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 6*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 6*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); - - //perform mul operation - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(ROw5): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4 + 5*cs_a)); - ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 5*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); - ymm2 = 
_mm256_broadcast_sd((double const *)(a11 + 2 + 5*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 5*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm13, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); - - //perform mul operation - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(ROw4): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 4*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 4*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - 
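The unpacklo/unpackhi + permute2f128 sequences in this block (before the solve and again before the stores) are the standard AVX idiom for transposing a 4x4 tile of doubles held in four ymm registers. A self-contained sketch of that idiom, independent of the kernel's register allocation:

    #include <immintrin.h>

    /* Transpose a 4x4 tile of doubles: r0..r3 hold the rows, c0..c3 receive
       the columns.  This is the same shuffle pattern the kernel uses to move
       B11 between column-major storage and row-per-register form. */
    static inline void transpose_4x4_pd(__m256d r0, __m256d r1,
                                        __m256d r2, __m256d r3,
                                        __m256d *c0, __m256d *c1,
                                        __m256d *c2, __m256d *c3)
    {
        __m256d t0 = _mm256_unpacklo_pd(r0, r1);    /* a00 a10 a02 a12 */
        __m256d t1 = _mm256_unpackhi_pd(r0, r1);    /* a01 a11 a03 a13 */
        __m256d t2 = _mm256_unpacklo_pd(r2, r3);    /* a20 a30 a22 a32 */
        __m256d t3 = _mm256_unpackhi_pd(r2, r3);    /* a21 a31 a23 a33 */

        *c0 = _mm256_permute2f128_pd(t0, t2, 0x20); /* a00 a10 a20 a30 */
        *c1 = _mm256_permute2f128_pd(t1, t3, 0x20); /* a01 a11 a21 a31 */
        *c2 = _mm256_permute2f128_pd(t0, t2, 0x31); /* a02 a12 a22 a32 */
        *c3 = _mm256_permute2f128_pd(t1, t3, 0x31); /* a03 a13 a23 a33 */
    }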
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); - } - } - }// End of multiples of D_MR blocks in m-dimension - - // Repetative A blocks will be 4*4 - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - i = m_remainder - 4; - a10 = L + (i + 4)*cs_a + i; //pointer to block of A to be used for GEMM - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - double *ptr_a11_dup = a11; - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-i-4;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); - - //Pick one element each column and create a 4 element vector and store - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = ptr_a11_dup; //pointer to block of A to be used for TRSM - b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 
+ cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm16 = _mm256_broadcast_sd((double const *)(&ones)); - - ////unpacklow//// - ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - - //rearrange high elements - ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); - - //perform mul operation - ymm10 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ///unpack high/// - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] - } - - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to block of B to be used for TRSM - - k_iter = (m - i - 4); //number of times GEMM to be performed - - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = 
_mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); 
//B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - n_remainder = n_remainder - 4; - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)() n = 3 - { - a10 = D_A_pack; - a11 = L + (i*cs_a) + i; - b01 = B + i + 4; - b11 = B + i; - - k_iter = (m - i - 4); - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= 
B01[0-3][1] - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_broadcast_sd((double const *)(&ones)); - ymm2 = _mm256_broadcast_sd((double const *)(&ones)); - ymm3 = _mm256_broadcast_sd((double const *)(&ones)); - } - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //extract a33 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //perform mul operation - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); - - //extract a22 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(ROw3): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 + 3*cs_a)); - ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); - - //perform mul operation - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); - - //extract a11 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 + 2*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); - ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); - - //perform mul operation - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); - - //extract a00 - ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); - - //(ROw2): FMA operations - ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); - - //perform mul operation - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ///unpack high/// 
- ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(3 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - } - else if(2 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - } - else if(1 == n_remainder) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - } - } - m_remainder -= 4; - } - - if(m_remainder) - { - a10 = L + m_remainder*cs_a; - - // Do transpose for a10 & store in D_A_pack - double *ptr_a10_dup = D_A_pack; - if(3 == m_remainder) // Repetative A blocks will be 3*3 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, 
ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double 
*)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2 + 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm2, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1 + 2)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - 
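The xmm/ymm combination used just above is how the kernel reads and writes exactly three doubles of a column without touching the fourth: two elements arrive through a 128-bit load, the third through a broadcast, and on the way back the low half is written with _mm_storeu_pd and the remaining element with _mm_storel_pd. A stand-alone sketch of the same idiom (the pointer name col and the helper are hypothetical):

    #include <immintrin.h>

    /* Read/modify/write col[0..2] while leaving col[3] untouched - the same
       three-element fringe handling used in the surrounding kernel. */
    static inline void scale_three(double *col, double s)
    {
        __m128d lo = _mm_loadu_pd(col);                       /* col[0], col[1]       */
        __m256d v  = _mm256_broadcast_sd(col + 2);            /* col[2] in every lane */
        v = _mm256_insertf128_pd(v, lo, 0);                   /* col0 col1 col2 col2  */

        v = _mm256_mul_pd(v, _mm256_set1_pd(s));              /* any arithmetic on v  */

        _mm_storeu_pd(col, _mm256_extractf128_pd(v, 0));      /* write col[0], col[1] */
        _mm_storel_pd(col + 2, _mm256_extractf128_pd(v, 1));  /* write col[2] only    */
    }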
xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm1, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm0, 1)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - } - else if(2 == m_remainder) // Repetative A blocks will be 2*2 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); 
//store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - xmm5 = _mm256_extractf128_pd(ymm2, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - xmm5 = _mm256_extractf128_pd(ymm1, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code 
begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 0)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); - - xmm5 = _mm256_extractf128_pd(ymm0, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - - } - else if(1 == m_remainder) // Repetative A blocks will be 1*1 - { - dim_t p_lda = 4; // packed leading dimension - for(dim_t x =0;x < m-m_remainder;x++) - { - ymm0 = _mm256_loadu_pd((double const *)(a10 + x*cs_a)); - _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); - } - //cols - for(j = (n - D_NR); (j + 1) > 0; j -= D_NR) //loop along 'N' dimension - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); - ymm4 = _mm256_fmadd_pd(ymm2, ymm0, ymm4); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) - 
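The 0x0E blend just above, and the 0x0C and 0x08 variants in the earlier remainder blocks, are what keep the rows beyond m_remainder intact: a set bit in the immediate selects the lane from the originally loaded B11 value, so only the first one, two or three doubles of each column are actually replaced before the store. A small illustrative helper (hypothetical name; the blend immediate must be a compile-time constant, hence the switch):

    #include <immintrin.h>

    /* Merge an updated column back into B11 so that only the first 'rows'
       doubles change; the remaining lanes keep their original contents. */
    static inline __m256d merge_partial_rows(__m256d updated, __m256d original, int rows)
    {
        switch (rows)
        {
            case 1:  return _mm256_blend_pd(updated, original, 0x0E); /* keep lanes 1-3 */
            case 2:  return _mm256_blend_pd(updated, original, 0x0C); /* keep lanes 2-3 */
            case 3:  return _mm256_blend_pd(updated, original, 0x08); /* keep lane 3    */
            default: return updated;                                  /* all lanes new  */
        }
    }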
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) - - dtrsm_AuXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); - } - dim_t n_remainder = j + D_NR; - if((n_remainder >= 4)) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM - b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); - n_remainder -= 4; - } - if(n_remainder) - { - a10 = D_A_pack; - a11 = L; //pointer to block of A to be used for TRSM - b01 = B + m_remainder; //pointer to block of B to be used for GEMM - b11 = B; //pointer to block of B to be used for TRSM - - k_iter = (m - m_remainder); //number of times GEMM to be performed - - 
#ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - #endif - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - if(3 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); - ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); - } - else if(2 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ///implement TRSM/// - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); - - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); - - dtrsm_AuXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); - } - else if(1 == n_remainder) - { - ///GEMM code begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ymm0 = _mm256_loadu_pd((double const *)(a10)); - - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); - ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); - - b01 += 1; //move to next row of B - a10 += p_lda; - } - - //register to hold alpha - ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ///implement TRSM/// - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); - dtrsm_AuXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); - } - } - 
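All of the solve steps in the blocks above funnel the diagonal through DTRSM_SMALL_DIV_OR_SCALE and the packed d11_pack array. The intent of the BLIS_ENABLE_TRSM_PREINVERSION configuration, as the diagonal-packing code earlier in this function suggests, is to pay for the reciprocals once at pack time so that the per-column solve step presumably becomes a multiply instead of a divide. A scalar sketch of that trade-off (illustrative only; the real packing helper and macro are defined elsewhere in this file):

    /* Pack the diagonal of an mr x mr A11 block (mr is a generic block size).
       With preinversion the reciprocal is stored and the solve multiplies;
       otherwise the diagonal is stored as-is and the solve divides.
       A unit diagonal stores 1.0 either way. */
    for (dim_t k = 0; k < mr; k++)
    {
        double dkk = is_unitdiag ? 1.0 : a11[k + k * cs_a];
    #ifdef BLIS_ENABLE_TRSM_PREINVERSION
        d11_pack[k] = 1.0 / dkk;      /* later: x = x * d11_pack[k] */
    #else
        d11_pack[k] = dkk;            /* later: x = x / d11_pack[k] */
    #endif
    }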
} - } - - if ((required_packing_A == 1) && - bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, &local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - /*implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, transpose + *dimensions: X:mxn A:nxn B: mxn + * + * b11---> a01 ----> + ***************** *********** + *b01*b11* * * * * * * +b11 * * * * * **a01 * * a11 + | ***************** ********* | + | * * * * * *a11* * | + | * * * * * * * * | + v ***************** ****** v + * * * * * * * + * * * * * * * + ***************** * * + * + *implements TRSM for the case XA = alpha * B *A is upper triangular, non-unit diagonal/unit diagonal, no transpose *dimensions: X:mxn A:nxn B: mxn * @@ -11412,7 +1902,7 @@ b11 * * * * * **a01 * * a11 */ -BLIS_INLINE err_t bli_dtrsm_small_XAuB +BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB ( obj_t* AlphaObj, obj_t* a, @@ -11423,2863 +1913,22 @@ BLIS_INLINE err_t bli_dtrsm_small_XAuB { dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns + dim_t d_mr = 8,d_nr = 6; - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_NR * n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - - __m128d xmm5; - - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. 
Repeat b for m cols of B in steps of D_MR - */ - - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + else { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - dim_t p_lda = j; // packed leading dimension - - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', j, 0, a01, cs_a, D_A_pack, p_lda); - - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. - */ - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. 
- */ - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); - - ///implement TRSM/// - - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - ymm14 = 
_mm256_fnmadd_pd(ymm1, ymm12, ymm14); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); 
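/* Editorial note (not in the patch): each _mm256_fmsub_pd(b11_vec, alpha_vec, acc)
 * in this block computes alpha*B11 - (A01*B10 accumulation); the alpha scaling
 * of B11 and the subtraction of the GEMM partial product are fused into a
 * single FMA before the forward-substitution steps begin. */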
- ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a )); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a )); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a )); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a )); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = 
_mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, 
ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 
= _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - 
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); 
//A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); - ymm13 = 
_mm256_fnmadd_pd(ymm1, ymm7, ymm13); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); - ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -= 1; - i += 1; - } + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A } - - dim_t n_remainder = n - j; - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. 
- */ - - if(n_remainder >= 4) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 3 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); 
//B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a )); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + 
cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; 
//move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - 
ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a )); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); 
//B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - - ymm5 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); - - m_remainder -= 1; - i += 1; - } - j += 4; - n_remainder -= 4; - } - - if(n_remainder == 3) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - 
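/* How the d11_pack buffer built above is consumed: with pre-inversion enabled
 * it holds reciprocals of the diagonal (ones / a(k,k)), so the solve only
 * needs a multiply; with pre-inversion disabled it holds the raw diagonal and
 * the solve divides.  A minimal sketch of a DTRSM_SMALL_DIV_OR_SCALE
 * definition consistent with that packing -- an assumption, since the actual
 * definition is not shown in this hunk: */
#ifdef BLIS_ENABLE_TRSM_PREINVERSION
  #define DTRSM_SMALL_DIV_OR_SCALE  _mm256_mul_pd
#else
  #define DTRSM_SMALL_DIV_OR_SCALE  _mm256_div_pd
#endif
/* For a unit-diagonal A the buffer is left filled with ones, so the same call
 * becomes a no-op scaling in either configuration. */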
_mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b*2), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ///implement TRSM/// - - //extract a00 - 
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - 
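/* A sketch of the per-column substitution carried out in this block, assuming
 * the right-side solve X*A = alpha*B with the off-diagonal entry A(k,c) read
 * from a11[k + c*cs_a] (the transposed variant handled by the combined kernel
 * is algebraically identical with the roles of rows and columns of A swapped):
 *
 *   X(:,0) = B'(:,0) / A(0,0)
 *   X(:,c) = ( B'(:,c) - sum_{k<c} X(:,k) * A(k,c) ) / A(c,c),  c = 1, 2, ...
 *
 * where B' is the alpha-scaled B11 block already reduced by the GEMM above.
 * Each subtraction maps to an _mm256_fnmadd_pd and each division to
 * DTRSM_SMALL_DIV_OR_SCALE on the packed diagonal, which turns into a multiply
 * when pre-inversion is enabled. */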
//(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - 
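/* Write-back for the 3-row fringe that follows: only lanes 0-2 of each ymm
 * hold live results, so the original B11 column is re-loaded and merged with
 * _mm256_blend_pd(orig, result, 0x07) -- lanes 0-2 from the result, lane 3
 * from the original -- so that the full 256-bit stores cannot clobber the
 * element just past the fringe; the last column uses a true 2+1 partial store
 * instead.  A minimal sketch of the assemble/spill pattern used for such
 * 3-element columns (load3/store3 are illustrative names only): */
static inline __m256d dtrsm_small_load3(const double *p)
{
    __m128d lo = _mm_loadu_pd(p);                        /* p[0], p[1]        */
    __m256d v  = _mm256_broadcast_sd(p + 2);             /* p[2] in all lanes */
    return _mm256_insertf128_pd(v, lo, 0);               /* { p0, p1, p2, p2 }*/
}
static inline void dtrsm_small_store3(double *p, __m256d v)
{
    _mm_storeu_pd(p, _mm256_extractf128_pd(v, 0));       /* p[0], p[1]        */
    _mm_storel_pd(p + 2, _mm256_extractf128_pd(v, 1));   /* p[2]              */
}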
- ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = 
_mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); 
//B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - - m_remainder -= 1; - i += 1; - } - j += 3; - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += 
(B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double 
const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of 
GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + 
cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - - m_remainder -= 1; - i += 1; - } - j += 2; - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = j; // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = j/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = 
_mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ///implement TRSM/// - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - } - - dim_t m_remainder = m - i; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - - m_remainder -= 4; - i += 4; - } - - if(m_remainder == 3) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); - - m_remainder -= 3; - i += 3; - } - else if(m_remainder == 2) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - - m_remainder -= 2; - i += 2; - } - else if(m_remainder == 1) - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); - - m_remainder -= 1; - i += 1; - } - j += 1; - n_remainder -= 1; - } - - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - return BLIS_SUCCESS; -} - -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal/unit diagonal, transpose - *dimensions: X:mxn A:nxn B: mxn - * - * b11---> a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * 
**a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ - -BLIS_INLINE err_t bli_dtrsm_small_XAltB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B dim_t i, j, k; //loop variablse @@ -14298,7 +1947,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); + double d11_pack[d_mr] __attribute__((aligned(64))); rntm_t rntm; bli_rntm_init_from_global( &rntm ); @@ -14310,19 +1959,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); - if( (D_NR * n * sizeof(double)) > buffer_size) + if( (d_nr * n * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; - + if (required_packing_A == 1) { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; } //ymm scratch reginsters @@ -14334,65 +1983,66 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB __m128d xmm5; /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) First there will be no GEMM and no packing of a01 because it is only TRSM b. Using packed a01 block and b10 block perform GEMM operation c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR + d. Repeat b for m cols of B in steps of d_mr */ - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' direction { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + + //double *ptr_a10_dup = D_A_pack; dim_t p_lda = j; // packed leading dimension // perform copy of A to packed buffer D_A_pack - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda); + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. 
This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack('R', j, 1, a01, cs_a, D_A_pack, p_lda,d_nr); - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ + /* + Pack 6 diagonal elements of A block into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_dtrsm_small_pack('R', j, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } /* a. Perform GEMM using a01, b10. b. Perform TRSM on a11, b11 c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where + along m dimension for every d_mr columns of B10 where packed A buffer is reused in computing all m cols of B. d. Same approach is used in remaining fringe cases. */ - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - /*Fill zeros into ymm registers used in gemm accumulations */ BLIS_SET_YMM_REG_ZEROS @@ -14401,7 +2051,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB For first itteration there will be no GEMM operation where k_iter are zero */ - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) /* Load b11 of size 8x6 and multiply with alpha @@ -14409,43 +2059,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB and peform TRSM operation. 
*/ - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha -= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); + BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -14467,27 +2081,27 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); ymm12 = _mm256_fnmadd_pd(ymm1, ymm4, ymm12); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); ymm14 = _mm256_fnmadd_pd(ymm1, ymm4, ymm14); @@ -14501,22 +2115,22 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - ymm1 = _mm256_broadcast_sd((double 
const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); ymm12 = _mm256_fnmadd_pd(ymm1, ymm6, ymm12); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); ymm14 = _mm256_fnmadd_pd(ymm1, ymm6, ymm14); @@ -14530,17 +2144,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); ymm12 = _mm256_fnmadd_pd(ymm1, ymm8, ymm12); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); ymm14 = _mm256_fnmadd_pd(ymm1, ymm8, ymm14); @@ -14554,12 +2168,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); ymm12 = _mm256_fnmadd_pd(ymm1, ymm10, ymm12); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); ymm14 = _mm256_fnmadd_pd(ymm1, ymm10, ymm14); @@ -14573,7 +2187,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); ymm14 = _mm256_fnmadd_pd(ymm1, ymm12, ymm14); @@ -14599,67 +2213,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = 
_mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -14671,19 +2238,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -14694,16 +2261,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -14714,13 +2281,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -14731,10 +2298,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); @@ -14745,7 +2312,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); @@ -14764,67 +2331,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder == 3) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] 
B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -14836,19 +2356,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -14859,16 +2379,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -14879,13 +2399,13 @@ 
BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -14896,10 +2416,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); @@ -14910,7 +2430,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); @@ -14918,13 +2438,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_loadu_pd((double const *)b11); ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); + ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); @@ -14941,67 +2461,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 2) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + 
BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15013,19 +2486,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -15036,16 +2509,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double 
const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -15056,13 +2529,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -15073,10 +2546,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); @@ -15087,7 +2560,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); @@ -15118,67 +2591,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 1) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15190,19 +2616,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm3, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm3, ymm13); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -15213,16 +2639,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const 
*)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm5, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm5, ymm13); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -15233,13 +2659,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm7, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm7, ymm13); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -15250,10 +2676,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm9, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm9, ymm13); ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); @@ -15264,7 +2690,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); //(Row 5): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); ymm13 = _mm256_fnmadd_pd(ymm1, ymm11, ymm13); ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); @@ -15304,70 +2730,115 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(n_remainder >= 4) { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = j; // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); + 
_mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); + } ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); @@ -15382,86 +2853,22 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const 
*)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15475,17 +2882,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); @@ -15499,12 +2906,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); ymm10 = 
_mm256_fnmadd_pd(ymm1, ymm6, ymm10); @@ -15518,7 +2925,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); @@ -15540,39 +2947,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha @@ -15598,13 +2983,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -15615,10 +3000,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -15629,7 
+3014,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -15646,39 +3031,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder == 3) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha @@ -15693,10 +3056,11 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 ///implement TRSM/// + //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); @@ -15705,13 +3069,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -15722,10 +3086,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = 
_mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -15736,7 +3100,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -15765,39 +3129,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 2) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -15824,13 +3166,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -15841,10 +3183,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 
2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -15855,7 +3197,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); @@ -15882,39 +3224,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 1) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -15940,13 +3260,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -15957,10 +3277,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -15971,24 +3291,24 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); + ymm0 = _mm256_loadu_pd((double const *)b11); ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); m_remainder -= 1; i += 1; @@ -15999,69 +3319,109 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(n_remainder == 3) { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = j; // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = 
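The transa branch of the packing loops above (and the analogous n_remainder versions later in this function) is the standard 4x4 double-precision register transpose built from unpack and 128-bit permute steps. A self-contained sketch of that idiom, with an illustrative function name:

#include <immintrin.h>

/* Transpose a 4x4 block of doubles held in four ymm registers.
 * r0..r3 are the four input rows; t0..t3 receive the four columns. */
static inline void transpose_4x4_pd(__m256d r0, __m256d r1,
                                    __m256d r2, __m256d r3,
                                    __m256d *t0, __m256d *t1,
                                    __m256d *t2, __m256d *t3)
{
    __m256d lo01 = _mm256_unpacklo_pd(r0, r1);   /* r0[0] r1[0] r0[2] r1[2] */
    __m256d hi01 = _mm256_unpackhi_pd(r0, r1);   /* r0[1] r1[1] r0[3] r1[3] */
    __m256d lo23 = _mm256_unpacklo_pd(r2, r3);   /* r2[0] r3[0] r2[2] r3[2] */
    __m256d hi23 = _mm256_unpackhi_pd(r2, r3);   /* r2[1] r3[1] r2[3] r3[3] */

    *t0 = _mm256_permute2f128_pd(lo01, lo23, 0x20);  /* column 0 */
    *t1 = _mm256_permute2f128_pd(hi01, hi23, 0x20);  /* column 1 */
    *t2 = _mm256_permute2f128_pd(lo01, lo23, 0x31);  /* column 2 */
    *t3 = _mm256_permute2f128_pd(hi01, hi23, 0x31);  /* column 3 */
}

The non-transpose branch needs no shuffling at all, which is why it collapses to plain strided loads and stores plus a two-element xmm remainder.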
_mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + } ymm3 = _mm256_broadcast_sd((double const *)&ones); ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); @@ -16077,53 +3437,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -16157,12 +3484,12 @@ 
BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); @@ -16176,7 +3503,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); @@ -16196,35 +3523,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -16246,10 +3555,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -16260,7 +3569,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -16276,48 +3585,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder == 3) { a01 = 
D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16329,10 +3609,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -16343,25 +3623,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] 
B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) m_remainder -= 3; i += 3; @@ -16369,47 +3636,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 2) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16421,10 +3660,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -16435,23 +3674,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm7 = 
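The explicit loads, blends and stores being folded into BLIS_PRE_DTRSM_SMALL_3N_3M / BLIS_POST_DTRSM_SMALL_3N_3M here all come down to one partial-vector idiom: update the first three doubles of a column of B without touching the fourth element in memory. A minimal sketch (the helper name is illustrative):

#include <immintrin.h>

/* Store x[0], x[1], x[2] to p[0..2]; p[3] is left untouched. */
static inline void store3_pd(double *p, __m256d x)
{
    _mm_storeu_pd(p, _mm256_castpd256_pd128(x));        /* x[0], x[1]                   */
    _mm_storel_pd(p + 2, _mm256_extractf128_pd(x, 1));  /* low double of upper half = x[2] */
}

On the load side, blending the computed lanes over the freshly loaded ones with _mm256_blend_pd(orig, computed, 0x07) is what keeps the fourth row of B bit-exact through the alpha scaling and the solve.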
DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); + BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) m_remainder -= 2; i += 2; @@ -16459,46 +3687,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 1) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16510,10 +3711,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -16524,21 +3725,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB 
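Every solve step in these kernels goes through DTRSM_SMALL_DIV_OR_SCALE with one broadcast entry of d11_pack. Since the diagonal pack stores reciprocals when pre-inversion is enabled and the raw diagonal otherwise, the macro presumably reduces to either a multiply or a true divide; a sketch of that behaviour (the configuration macro name is an assumption here, not taken from this hunk):

#ifdef BLIS_ENABLE_TRSM_PREINVERSION
  /* d11_pack holds 1/a_kk: one divide at pack time, multiplies thereafter */
  #define DTRSM_SMALL_DIV_OR_SCALE_SKETCH(b, d) _mm256_mul_pd((b), (d))
#else
  /* d11_pack holds a_kk: every solve step performs the divide directly   */
  #define DTRSM_SMALL_DIV_OR_SCALE_SKETCH(b, d) _mm256_div_pd((b), (d))
#endif

Pre-inversion trades one divide per diagonal element for a cheap multiply per solved vector, at the cost of slightly different rounding than a direct divide.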
ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) m_remainder -= 1; i += 1; @@ -16548,68 +3740,103 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB } else if(n_remainder == 2) { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = j; // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = 
_mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + } ymm2 = _mm256_broadcast_sd((double const *)&ones); ymm3 = _mm256_broadcast_sd((double const *)&ones); @@ -16626,48 +3853,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 
cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -16695,7 +3894,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); @@ -16713,7 +3912,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -16723,21 +3922,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -16757,7 +3942,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); @@ -16772,7 +3957,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder == 3) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -16782,31 +3967,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16818,22 +3981,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) m_remainder -= 3; i += 3; @@ -16841,7 +3994,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 2) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -16851,30 +4004,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] 
B10[3][0] + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 @@ -16885,20 +4017,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); + BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) m_remainder -= 2; i += 2; @@ -16906,7 +4030,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 1) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -16916,29 +4040,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm5 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * 
alpha-= ymm2 + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16950,18 +4054,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 1):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) m_remainder -= 1; i += 1; @@ -16971,60 +4069,82 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB } else if(n_remainder == 1) { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a01 = L + j*rs_a; //pointer to block of A to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = j; // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - 
ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = p_lda/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); @@ -17049,37 +4169,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction + for(i = 0; (i+d_mr-1) < m; i += d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j; //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - #endif - ymm3 = _mm256_setzero_pd(); ymm4 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -17105,7 +4207,7 @@ BLIS_INLINE err_t 
bli_dtrsm_small_XAltB if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -17114,18 +4216,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -17147,7 +4238,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB if(m_remainder == 3) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -17156,25 +4247,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -17182,12 +4257,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); + BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) m_remainder -= 3; i += 3; @@ -17195,7 +4265,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 2) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -17204,24 +4274,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = 
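The BLIS_DTRSM_SMALL_GEMM_1nx4m macro dropped in above replaces hand-written accumulation loops like the ones deleted in this hunk. A rough standalone equivalent of one such 1-column-by-4-row accumulation (a sketch of the pattern, not the macro's literal expansion):

#include <immintrin.h>

/* Accumulate sum_k B10[0..3][k] * A01[k][0], walking the packed A panel one
 * row at a time and B10 one column (stride cs_b) at a time. */
static inline __m256d gemm_1nx4m_sketch(const double *a01, const double *b10,
                                        dim_t cs_b, dim_t k_iter)
{
    __m256d acc = _mm256_setzero_pd();
    for (dim_t k = 0; k < k_iter; k++)
    {
        __m256d b = _mm256_loadu_pd(b10);        /* four rows of B10, column k   */
        __m256d a = _mm256_broadcast_sd(a01);    /* A01[k][0]                    */
        acc = _mm256_fmadd_pd(a, b, acc);        /* acc += b * a                 */
        a01 += 1;                                /* next row of the packed panel */
        b10 += cs_b;                             /* next column of B10           */
    }
    return acc;   /* caller folds this into b11 with fmsub against alpha */
}

The 2nx and 3nx variants used earlier follow the same shape with one extra broadcast-FMA pair per additional packed column.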
_mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -17229,11 +4284,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); + BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) m_remainder -= 2; i += 2; @@ -17241,7 +4292,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB else if(m_remainder == 1) { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM b10 = B + i; //pointer to block of B to be used in GEMM b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM @@ -17250,23 +4301,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -17274,9 +4311,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) m_remainder -= 1; i += 1; @@ -17294,2870 +4329,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB return BLIS_SUCCESS; } -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal/unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - * - * <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* - -*/ -BLIS_INLINE err_t bli_dtrsm_small_XAlB -( - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl -) -{ - dim_t m = bli_obj_length(b); //number of rows - dim_t n = 
bli_obj_width(b); //number of columns - - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - - double ones = 1.0; - bool is_unitdiag = bli_obj_has_unit_diag(a); - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double* restrict L = a->buffer; //pointer to matrix A - double* restrict B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - - gint_t required_packing_A = 1; - mem_t local_mem_buf_A_s = {0}; - double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); - rntm_t rntm; - - bli_rntm_init_from_global( &rntm ); - bli_rntm_set_num_threads_only( 1, &rntm ); - bli_membrk_rntm_set_membrk( &rntm ); - - siz_t buffer_size = bli_pool_block_size( - bli_membrk_pool( - bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), - bli_rntm_membrk(&rntm))); - - if( (D_NR * n * sizeof(double)) > buffer_size) - return BLIS_NOT_YET_IMPLEMENTED; - - if (required_packing_A == 1) - { - // Get the buffer from the pool. - bli_membrk_acquire_m(&rntm, - buffer_size, - BLIS_BITVAL_BUFFER_FOR_A_BLOCK, - &local_mem_buf_A_s); - if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; - D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); - if(NULL==D_A_pack) return BLIS_NULL_POINTER; - } - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - - __m128d xmm5; - - /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR - a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) - First there will be no GEMM and no packing of a01 because it is only TRSM - b. Using packed a01 block and b10 block perform GEMM operation - c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR - */ - - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - - dim_t p_lda = (n-j-D_NR); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', p_lda, 0, a01, cs_a, D_A_pack, p_lda); - - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); - - /* - a. Perform GEMM using a01, b10. - b. Perform TRSM on a11, b11 - c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every D_MR columns of B10 where - packed A buffer is reused in computing all m cols of B. - d. Same approach is used in remaining fringe cases. 
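For readability, the blocked pack + GEMM + TRSM scheme described in the comment above computes the same result as this scalar model of XA = alpha * B with a lower-triangular, non-transposed A (a reference-style sketch for orientation, not code taken from BLIS):

/* A is n x n lower triangular, B is m x n and is overwritten by X; both are
 * column-major with leading dimensions lda and ldb.  dim_t and bool are
 * BLIS's signed index and boolean types. */
static void xalb_scalar_model(const double *A, double *B, double alpha,
                              dim_t m, dim_t n, dim_t lda, dim_t ldb,
                              bool is_unitdiag)
{
    for (dim_t k = n - 1; k >= 0; k--)            /* right-most column of X first */
    {
        for (dim_t i = 0; i < m; i++)
        {
            double acc = alpha * B[i + k*ldb];
            for (dim_t p = k + 1; p < n; p++)     /* already-solved columns: the "GEMM" part */
                acc -= B[i + p*ldb] * A[p + k*lda];
            B[i + k*ldb] = is_unitdiag ? acc : acc / A[k + k*lda];   /* the "TRSM" part */
        }
    }
}

In the vector kernel the inner loop over already-solved columns is split in two: whole solved column blocks are handled by the packed GEMM, while dependencies inside the current 6-column block are resolved in registers through the fnmadd chains.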
- */ - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR)//loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR);//no. of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - /*Fill zeros into ymm registers used in gemm accumulations */ - BLIS_SET_YMM_REG_ZEROS - - /* - Peform GEMM between a01 and b10 blocks - For first itteration there will be no GEMM operation - where k_iter are zero - */ - - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) - - /* - Load b11 of size 8x6 and multiply with alpha - Add the GEMM output to b11 - and peform TRSM operation. - */ - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); - - ///implement TRSM/// - - /* - Compute 6x8 TRSM block by using GEMM block output in register - a. The 6x8 input (gemm outputs) are stored in combinations of ymm registers - 1. ymm3, ymm4 2. ymm5, ymm6 3. ymm7, ymm8, 4. ymm9, ymm10 - 5. ymm11, ymm12 6. ymm13,ymm14 - b. 
Towards the end TRSM output will be stored back into b11 - */ - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm14, ymm4); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm12, ymm4); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); - - ymm3 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*4 + 4), ymm12); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); - } - - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (m_remainder - 4) + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = 
_mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=4; - } - - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 
0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); 
//A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, 
ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (j*cs_b); - - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += 
(B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - - ///implement TRSM/// - - //extract a55 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); - ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - - //extract a44 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); - - //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5)); - ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm13, ymm3); - - ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4)); - ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm11, ymm3); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); 
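The solve blocks above all follow the same register-level recipe: scale the current B column by the packed diagonal entry (DTRSM_SMALL_DIV_OR_SCALE divides by a_kk, or multiplies by its pre-inverted value when preinversion is enabled), then eliminate that column from the not-yet-solved columns with _mm256_fnmadd_pd, which computes c - a*b per lane. A minimal scalar sketch of the same recipe for one row of B against the 6x6 lower-triangular diagonal block; the function name and loop bounds are illustrative only, not the kernel's actual code:

    static void dtrsm_small_block_row_sketch
         (
           const double *a11,   /* 6x6 diagonal block, column-major, stride cs_a */
           double       *b,     /* one row of B11, nb entries                    */
           dim_t         nb,
           dim_t         cs_a
         )
    {
        for (dim_t k = nb; k--; )                     /* last column first        */
        {
            b[k] /= a11[k + k*cs_a];                  /* DIV_OR_SCALE per lane    */
            for (dim_t jj = 0; jj < k; jj++)
                b[jj] -= a11[k + jj*cs_a] * b[k];     /* the fnmadd updates       */
        }
    }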
- - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_pd(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_pd(ymm0, ymm13, 0x01); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*4), ymm11); - _mm256_storeu_pd((double *)(b11 + cs_b*5), ymm13); - - m_remainder -=1; - } - } - } - - dim_t n_remainder = j + D_NR; - - /* - Reminder cases starts here: - a. Similar logic and code flow used in computing full block (6x8) - above holds for reminder cases too. - */ - - if(n_remainder >= 4) - { - a01 = L + (n_remainder - 4)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 3 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal 
elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); 
//B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); - } - - dim_t m_remainder = i + D_MR; - 
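The remainder paths in this kernel compute a full 4-wide column even when fewer than four rows are left, and commit only the valid lanes by blending the solved vector with the original contents of b11 (immediate 0x07, 0x03 or 0x01 keeps the low 3, 2 or 1 lanes) before storing. A self-contained sketch of that masked update; store_partial_col_sketch and its arguments are hypothetical names, and, like the surrounding code, it assumes the full 4-lane span is addressable:

    #include "immintrin.h"

    static void store_partial_col_sketch(double *col, __m256d solved, int valid)
    {
        __m256d old = _mm256_loadu_pd(col);            /* current column contents   */
        __m256d out;
        switch (valid)
        {
            case 3:  out = _mm256_blend_pd(old, solved, 0x07); break;
            case 2:  out = _mm256_blend_pd(old, solved, 0x03); break;
            case 1:  out = _mm256_blend_pd(old, solved, 0x01); break;
            default: out = solved;                             break;
        }
        _mm256_storeu_pd(col, out);                    /* untouched lanes preserved */
    }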
if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double 
const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm9); - - m_remainder -=4; - } - - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); 
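When only three rows of a column are live and a full 4-lane load of the original might over-read (the last column of a 3-row tile), the code above assembles a 256-bit value from a 128-bit load of the first two doubles plus a broadcast of the third, and writes it back with an extract followed by storeu/storel. A compact sketch of that partial-column I/O; load3_sketch and store3_sketch are illustrative names only:

    #include "immintrin.h"

    static __m256d load3_sketch(const double *col)
    {
        __m128d lo = _mm_loadu_pd(col);                 /* col[0], col[1]              */
        __m256d hi = _mm256_broadcast_sd(col + 2);      /* col[2] in all lanes         */
        return _mm256_insertf128_pd(hi, lo, 0);         /* col[0] col[1] col[2] col[2] */
    }

    static void store3_sketch(double *col, __m256d v)
    {
        _mm_storeu_pd(col, _mm256_extractf128_pd(v, 0));      /* col[0], col[1] */
        _mm_storel_pd(col + 2, _mm256_extractf128_pd(v, 1));  /* col[2]         */
    }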
- ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm9, 1)); - - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] 
B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - xmm5 = _mm256_extractf128_pd(ymm9, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); - - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + 
p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ///implement TRSM/// - - //extract a33 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); - - ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3)); - ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b 
* 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); - - m_remainder -=1; - } - } - n_remainder -= 4; - } - - if(n_remainder == 3) - { - a01 = L + (n_remainder - 3)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 2 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = 
_mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ///implement TRSM/// - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - 
_mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); - } - - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm7); - - m_remainder -=4; - } - - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to 
block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); - - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + 
(m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + 
(m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ///implement TRSM/// - - //extract a22 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); - - ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); - - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - - //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2)); - ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); - - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); - - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - - m_remainder -=1; - } - } - n_remainder -= 3; - } - else if(n_remainder == 2) - { - a01 = L + (n_remainder - 2)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = 
(n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 1 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 1 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*1), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b*1), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double 
const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ///implement TRSM/// - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); - } - - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - 
_mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - - m_remainder -=4; - } - - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); - - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 
+ p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - - //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ///implement TRSM/// - //extract a11 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); - ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); - - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - 
- //(Row 1): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1)); - ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); - - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - - m_remainder -=1; - } - } - n_remainder -= 2; - } - else if(n_remainder == 1) - { - a01 = L + (n_remainder - 1)*cs_a + n_remainder; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - - double *ptr_a10_dup = D_A_pack; - - dim_t p_lda = (n-n_remainder); // packed leading dimension - // perform copy of A to packed buffer D_A_pack - - dim_t loop_count = (n-n_remainder)/4; - - for(dim_t x =0;x < loop_count;x++) - { - ymm15 = _mm256_loadu_pd((double const *)(a01 + cs_a * 0 + x*4)); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); - } - - dim_t remainder_loop_count = p_lda - loop_count*4; - - __m128d xmm0; - if(remainder_loop_count != 0) - { - xmm0 = _mm_loadu_pd((double const *)(a01 + cs_a * 0 + loop_count*4)); - _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); - } - - ymm4 = _mm256_broadcast_sd((double const *)&ones); - if(!is_unitdiag) - { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - - ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); - - ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); - #ifdef BLIS_DISABLE_TRSM_PREINVERSION - ymm4 = ymm1; - #endif - #ifdef BLIS_ENABLE_TRSM_PREINVERSION - ymm4 = _mm256_div_pd(ymm4, ymm1); - #endif - } - _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = 
_mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + 4), ymm4); - } - - dim_t m_remainder = i + D_MR; - if(m_remainder >= 4) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - _mm256_storeu_pd((double *)b11, ymm3); - - m_remainder -=4; - } - - if(m_remainder) - { - if(3 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = 
_mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); - - m_remainder -=3; - } - else if(2 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - - m_remainder -=2; - } - else if (1 == m_remainder) - { - a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - - ymm3 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ///implement TRSM/// - //extract a00 - ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); - ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - 
_mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); - - m_remainder -=1; - } - } - n_remainder -= 1; - } - - if ((required_packing_A == 1) && bli_mem_is_alloc( &local_mem_buf_A_s )) - { - bli_membrk_release(&rntm, - &local_mem_buf_A_s); - } - - return BLIS_SUCCESS; -} - /*implements TRSM for the case XA = alpha * B *A is upper triangular, non-unit diagonal/unit diagonal, transpose *dimensions: X:mxn A:nxn B: mxn @@ -20174,9 +4345,24 @@ b10 ***************** ************* * * * * * * * * * ***************** ******************* -*/ + *implements TRSM for the case XA = alpha * B + *A is lower triangular, non-unit diagonal/unit diagonal, no transpose + *dimensions: X:mxn A:nxn B: mxn + * + * <---b11 <---a11 + ***************** * + *b01*b11* * * * * + ^ * * * * * ^ * * + | ***************** | ******* + | * * * * * | * * * + | * * * * * a01* * * +b10 ***************** ************* + * * * * * * * * * + * * * * * * * * * + ***************** ******************* -BLIS_INLINE err_t bli_dtrsm_small_XAutB +*/ +BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB ( obj_t* AlphaObj, obj_t* a, @@ -20188,7 +4374,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB dim_t m = bli_obj_length(b); //number of rows dim_t n = bli_obj_width(b); //number of columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B dim_t i, j, k; //loop variablse @@ -20207,7 +4407,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB gint_t required_packing_A = 1; mem_t local_mem_buf_A_s = {0}; double *D_A_pack = NULL; - double d11_pack[D_MR] __attribute__((aligned(64))); + double d11_pack[d_mr] __attribute__((aligned(64))); rntm_t rntm; bli_rntm_init_from_global( &rntm ); @@ -20219,9 +4419,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), bli_rntm_membrk(&rntm))); - if( (D_NR * n * sizeof(double)) > buffer_size) + if( (d_nr * n * sizeof(double)) > buffer_size) return BLIS_NOT_YET_IMPLEMENTED; - + if (required_packing_A == 1) { // Get the buffer from the pool. @@ -20243,65 +4443,64 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB __m128d xmm5; /* - Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of D_NR + Performs solving TRSM for 6 rows at a time from 0 to n/6 in steps of d_nr a. Load and pack A (a01 block), the size of packing 6x6 to 6x (n-6) First there will be no GEMM and no packing of a01 because it is only TRSM b. Using packed a01 block and b10 block perform GEMM operation c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B - d. Repeat b for m cols of B in steps of D_MR + d. 
Repeat b for m cols of B in steps of d_mr */ - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction + for(j = (n-d_nr); (j+1) > 0; j -= d_nr) //loop along 'N' direction { - a01 = L + j +(j+D_NR)*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM + a01 = L + (j*rs_a) + (j+d_nr)*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); //pointer to block of A to be used for TRSM //double *ptr_a10_dup = D_A_pack; - dim_t p_lda = (n-j-D_NR); // packed leading dimension + dim_t p_lda = (n-j-d_nr); // packed leading dimension // perform copy of A to packed buffer D_A_pack - /* - Pack current A block (a01) into packed buffer memory D_A_pack - a. This a10 block is used in GEMM portion only and this - a01 block size will be increasing by D_NR for every next iteration - until it reaches 6x(n-6) which is the maximum GEMM alone block size in A - b. This packed buffer is reused to calculate all m cols of B matrix - */ - bli_dtrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda); + if(transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 6x(n-6) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack('R', p_lda, 1, a01, cs_a, D_A_pack, p_lda,d_nr); - /* - Pack 6 diagonal elements of A block into an array - a. This helps in utilze cache line efficiently in TRSM operation - b. store ones when input is unit diagonal - */ - dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,D_NR); + /* + Pack 6 diagonal elements of A block into an array + a. This helps utilize cache lines efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_nr); + } + else + { + bli_dtrsm_small_pack('R', p_lda, 0, a01, rs_a, D_A_pack, p_lda,d_nr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_nr); + } /* a. Perform GEMM using a01, b10. b. Perform TRSM on a11, b11 c. This loop GEMM+TRSM loops operates with 8x6 block size - along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. d. Same approach is used in remaining fringe cases.
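e. For reference only, an unblocked scalar sketch of what steps b and c compute
   together for this XA = alpha * B solve (editorial illustration, not the code
   path used below: it assumes B has already been scaled by alpha, whereas the
   vector code folds the alpha scaling into its b11 loads, works on the packed
   copy of A, and, when BLIS_ENABLE_TRSM_PREINVERSION is enabled, packs the
   pre-inverted diagonal into d11_pack so its DIV_OR_SCALE step becomes a
   multiply):

       for(k = n-1; k >= 0; k--)                 // columns of X, right to left
       {
           for(i = 0; i < m; i++)                // divide column k by the diagonal of A
               B[i + k*cs_b] /= L[k*cs_a + k*rs_a];
           for(j = 0; j < k; j++)                // update the columns still to be solved
               for(i = 0; i < m; i++)
                   B[i + j*cs_b] -= B[i + k*cs_b] * L[k*cs_a + j*rs_a];
       }

   Because rs_a and cs_a are swapped at the top of this function for the
   non-transpose case, the same indexing serves both the transposed
   upper-triangular and the non-transposed lower-triangular variants.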
*/ - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + a11 = L + j*cs_a + j*rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) - - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ BLIS_SET_YMM_REG_ZEROS @@ -20312,50 +4511,15 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB where k_iter are zero */ - BLIS_DTRSM_SMALL_GEMM_6x8(a01,b10,cs_b,p_lda,k_iter) + BLIS_DTRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) /* Load b11 of size 8x6 and multiply with alpha Add the GEMM output to b11 and peform TRSM operation. */ - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*4 + 4)); - - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - ymm12 = _mm256_fmsub_pd(ymm1, ymm15, ymm12); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*5 + 4)); - - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); - ymm14 = _mm256_fmsub_pd(ymm1, ymm15, ymm14); + BLIS_PRE_DTRSM_SMALL_6x8(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20378,22 +4542,22 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB //(row 5):FMA operations //ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); - ymm1 
= _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm14, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm14, ymm6); @@ -20410,17 +4574,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm12, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm12, ymm6); @@ -20437,12 +4601,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); @@ -20459,7 +4623,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); @@ -20498,71 +4662,24 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB _mm256_storeu_pd((double *)(b11 + cs_b*5 + 4), ymm14); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (m_remainder - 4) + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (m_remainder - 4) + (j+d_nr)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (j*cs_b); - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill 
zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20574,16 +4691,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); @@ -20595,13 +4712,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); @@ -20613,10 +4730,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -20628,7 +4745,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -20660,67 +4777,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 3); //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (j*cs_b); - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = 
_mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20732,16 +4802,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); @@ -20753,13 +4823,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); @@ -20771,10 +4841,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -20786,7 +4856,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -20828,67 +4898,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 2); //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (j*cs_b); - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = 
_mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20900,16 +4923,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); @@ -20921,13 +4944,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); @@ -20939,10 +4962,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -20954,7 +4977,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -20996,67 +5019,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (j*cs_a) + j; - b10 = B + (j+D_NR)*cs_b + (m_remainder - 1); //pointer to block of B to be used in GEMM + a11 = L + (j*cs_a) + (j*rs_a); + b10 = B + (j+d_nr)*cs_b + (m_remainder - 1); //pointer to block of B to 
be used in GEMM b11 = B + (m_remainder - 1) + (j*cs_b); - k_iter = (n-j-D_NR); //number of GEMM operations to be done(in blocks of 4x4) + k_iter = (n-j-d_nr); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 4)); //A01[0][4] - ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 5)); //A01[0][5] - ymm13 = _mm256_fmadd_pd(ymm2, ymm0, ymm13); - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*4)); - ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11); - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5)); - ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + // Load b11 of size 4x6 and multiply with alpha + BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -21068,16 +5044,16 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); //(row 5):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 4*rs_a)); ymm11 = _mm256_fnmadd_pd(ymm1, ymm13, ymm11); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 3*rs_a)); ymm9 = 
_mm256_fnmadd_pd(ymm1, ymm13, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm13, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm13, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a)); @@ -21089,13 +5065,13 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); //(row 4):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 3*rs_a)); ymm9 = _mm256_fnmadd_pd(ymm1, ymm11, ymm9); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm11, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm11, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a)); @@ -21107,10 +5083,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -21122,7 +5098,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -21164,7 +5140,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } } - dim_t n_remainder = j + D_NR; + dim_t n_remainder = j + d_nr; /* Reminder cases starts here: @@ -21174,70 +5150,115 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(n_remainder >= 4) { - a01 = L + (n_remainder - 4) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 4)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = 
_mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 3 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + x*4), ymm15); + } 
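+                // Tail of the non-transposed packing path: the loop above copied the
+                // columns of a01 four doubles at a time into D_A_pack (leading dimension p_lda);
+                // the block below moves any leftover columns two doubles at a time with 128-bit loads/stores.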
+ + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 3 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 3 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+ rs_a*3 + 3)); + } ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); @@ -21252,86 +5273,22 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*3), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*5), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 
- ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); //B11[4-7][0] * alpha-= ymm1 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); //B11[4-7][1] * alpha -= ymm3 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + 4)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 - ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); //B11[4-7][2] * alpha -= ymm5 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b*3 + 4)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); //B11[0-3][3] * alpha -= ymm6 - ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); //B11[4-7][3] * alpha -= ymm7 + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -21345,12 +5302,12 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm6 = 
_mm256_fnmadd_pd(ymm1, ymm10, ymm6); @@ -21367,7 +5324,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); @@ -21402,43 +5359,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB _mm256_storeu_pd((double *)(b11 + cs_b*3 + 4), ymm10); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -21464,10 +5399,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -21479,7 +5414,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double 
const *)(a11 + 2*cs_a)); @@ -21509,40 +5444,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -21570,10 +5482,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -21585,7 +5497,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -21625,39 +5537,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 4)*cs_b; //pointer to 
block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -21684,10 +5574,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -21699,7 +5589,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -21737,39 +5627,17 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 4)*cs_a + (n_remainder - 4)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 1) + (n_remainder - 4)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); 
//B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); //A01[0][3] - ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -21795,10 +5663,10 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); //(Row 3): FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 2*rs_a)); ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a)); @@ -21810,7 +5678,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -21826,19 +5694,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm9 = _mm256_blend_pd(ymm0, ymm9, 0x01); - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); - _mm_storel_pd((b11 + cs_b * 3), 
_mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm9, 0)); m_remainder -=1; } @@ -21848,70 +5716,110 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(n_remainder == 3) { - a01 = L + (n_remainder - 3) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 3)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 2 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 2 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 2 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); - ymm3 = _mm256_broadcast_sd((double const *)&ones); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + } + else + { + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+ rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+ rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + } ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); @@ -21926,54 +5834,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 3)*cs_b; //pointer to block of B 
to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b*2), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4 + cs_b*2), _MM_HINT_T0); - #endif - - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -22007,7 +5881,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); @@ -22040,39 +5914,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB _mm256_storeu_pd((double *)(b11 + cs_b*2 + 4), ymm8); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM 
implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -22095,7 +5951,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -22124,51 +5980,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] 
B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// - //extract a22 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); @@ -22177,7 +6003,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -22194,67 +6020,26 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); - _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + BLIS_POST_DTRSM_SMALL_3N_3M(b11,cs_b) m_remainder -=3; } else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = 
_mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -22266,7 +6051,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -22283,64 +6068,26 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); - xmm5 = _mm256_extractf128_pd(ymm7, 0); - _mm_storeu_pd((double *)(b11 + cs_b * 2),xmm5); + BLIS_POST_DTRSM_SMALL_3N_2M(b11,cs_b) m_remainder -=2; } else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 3)*cs_a + (n_remainder - 3)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 1) + (n_remainder - 3)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); //A01[0][2] - ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register 
to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); //B11[0-3][2] * alpha -= ymm4 + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -22352,7 +6099,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); //(row 2):FMA operations - ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 1*rs_a)); ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a)); @@ -22369,16 +6116,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_pd(ymm0, ymm7, 0x01); - - _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); - _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm7, 0)); + BLIS_POST_DTRSM_SMALL_3N_1M(b11,cs_b) m_remainder -=1; } @@ -22387,68 +6125,103 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } else if(n_remainder == 2) { - a01 = L + (n_remainder - 2) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 2)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + 
ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 1 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 1 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 1 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); if(!is_unitdiag) { - //broadcast diagonal elements of A11 - ymm0 = _mm256_broadcast_sd((double const *)(a11)); - ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + } + else + { + //broadcast diagonal 
elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + } ymm2 = _mm256_broadcast_sd((double const *)&ones); ymm3 = _mm256_broadcast_sd((double const *)&ones); @@ -22465,46 +6238,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + cs_b+4), _MM_HINT_T0); - #endif + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -22546,35 +6293,21 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for 
number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -22610,41 +6343,19 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -22661,58 +6372,26 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x07); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b*1 + 2)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x07); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); - _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + BLIS_POST_DTRSM_SMALL_2N_3M(b11,cs_b) m_remainder -=3; } 
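/*
 Illustrative scalar sketch (editor's addition, not part of the kernel and not
 invoked anywhere): it spells out what one 2-column remainder tile above
 computes -- the rank-k GEMM update against the packed a01 buffer, the alpha
 scaling, and the 2x2 back-substitution done through DTRSM_SMALL_DIV_OR_SCALE.
 The function name and argument list are hypothetical; d11_pack is assumed to
 hold either the diagonal of a11 or its reciprocal, depending on whether
 diagonal preinversion is enabled in the configuration.
*/
static void dtrsm_small_2col_tile_ref_sketch
     (
       const double *a01,      /* packed k_iter x 2 block used in the GEMM update      */
       const double *a11,      /* 2x2 triangular block; coupling element at a11[cs_a]  */
       const double *d11_pack, /* packed diagonal (or its reciprocal) of a11           */
       const double *b10,      /* m_rem x k_iter block of B feeding the GEMM           */
       double       *b11,      /* m_rem x 2 block of B solved in place                 */
       double        alpha,
       dim_t m_rem, dim_t k_iter, dim_t cs_a, dim_t cs_b, dim_t p_lda
     )
{
    for(dim_t i = 0; i < m_rem; i++)
    {
        /* GEMM part: acc_j = sum_k B10[i][k] * A01[k][j] */
        double acc0 = 0.0, acc1 = 0.0;
        for(dim_t k = 0; k < k_iter; k++)
        {
            acc0 += b10[i + k*cs_b] * a01[k + 0*p_lda];
            acc1 += b10[i + k*cs_b] * a01[k + 1*p_lda];
        }

        /* alpha scaling minus the GEMM output (the _mm256_fmsub_pd step) */
        double x0 = alpha*b11[i + 0*cs_b] - acc0;
        double x1 = alpha*b11[i + 1*cs_b] - acc1;

        /* back-substitution: later column first, then the earlier one;
           the division becomes a multiply when the diagonal is preinverted */
        x1 = x1 / d11_pack[1];
        x0 = (x0 - x1*a11[cs_a]) / d11_pack[0];

        b11[i + 0*cs_b] = x0;
        b11[i + 1*cs_b] = x1;
    }
}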
else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a11 @@ -22728,55 +6407,26 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x03); - xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 1)); - ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x03); - - _mm256_storeu_pd((double *)b11, ymm3); - xmm5 = _mm256_extractf128_pd(ymm5, 0); - _mm_storeu_pd((double *)(b11 + cs_b*1), xmm5); + BLIS_POST_DTRSM_SMALL_2N_2M(b11,cs_b) m_remainder -=2; } else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 2)*cs_a + (n_remainder - 2)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 1) + (n_remainder - 2)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - ymm3 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); 
//ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); //A01[0][1] - ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 - - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); //B11[0-3][1] * alpha-= ymm2 + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a11 @@ -22792,13 +6442,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_broadcast_sd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm0, ymm3, 0x01); - ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_pd(ymm0, ymm5, 0x01); - - _mm_storel_pd(b11 , _mm256_extractf128_pd(ymm3, 0)); - _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm5, 0)); + BLIS_POST_DTRSM_SMALL_2N_1M(b11,cs_b) m_remainder -=1; } @@ -22807,60 +6451,82 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } else if(n_remainder == 1) { - a01 = L + (n_remainder - 1) + n_remainder*cs_a; //pointer to block of A to be used in GEMM - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a01 = L + (n_remainder - 1)*rs_a + n_remainder*cs_a; //pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM double *ptr_a10_dup = D_A_pack; dim_t p_lda = (n-n_remainder); // packed leading dimension // perform copy of A to packed buffer D_A_pack - for(dim_t x =0;x < p_lda;x+=D_NR) + if(transa) { - ymm0 = _mm256_loadu_pd((double const *)(a01)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); - ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); - ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + for(dim_t x =0;x < p_lda;x+=d_nr) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); - _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); - _mm256_storeu_pd((double 
*)(ptr_a10_dup + p_lda*3), ymm9); + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); - ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); - ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); + ymm0 = _mm256_loadu_pd((double const *)(a01 + cs_a * 4)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a * 5)); - ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); - ymm5 = _mm256_broadcast_sd((double const *)&zero); + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_broadcast_sd((double const *)&zero); - ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); - ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); - ymm1 = _mm256_broadcast_sd((double const *)&zero); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_broadcast_sd((double const *)&zero); - ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); - _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); - _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4), _mm256_extractf128_pd(ymm6,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda), _mm256_extractf128_pd(ymm7,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*2), _mm256_extractf128_pd(ymm8,0)); + _mm_storeu_pd((double *)(ptr_a10_dup + 4 + p_lda*3), _mm256_extractf128_pd(ymm9,0)); - a01 += D_NR*cs_a; - ptr_a10_dup += D_NR; + a01 += d_nr*cs_a; + ptr_a10_dup += d_nr; + } + } + else + { + dim_t loop_count = (n-n_remainder)/4; + + for(dim_t x =0;x < loop_count;x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + rs_a * 0 + x*4)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + x*4), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count*4; + + __m128d xmm0; + if(remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + rs_a * 0 + loop_count*4)); + _mm_storeu_pd((double *)(ptr_a10_dup + p_lda * 0 + loop_count*4), xmm0); + } } ymm4 = _mm256_broadcast_sd((double const *)&ones); @@ -22885,37 +6551,20 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } _mm256_storeu_pd((double *)(d11_pack), ymm4); - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction + for(i = (m-d_mr); (i+1) > 0; i -= d_mr) //loop along 'M' direction { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + i + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (i) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) - #ifdef BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL - _mm_prefetch((char*)(b11 + 0), _MM_HINT_T0); - _mm_prefetch((char*)(b11 + 4), _MM_HINT_T0); - #endif + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS - ymm3 = 
_mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -22935,11 +6584,11 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB _mm256_storeu_pd((double *)(b11 + 4), ymm4); } - dim_t m_remainder = i + D_MR; + dim_t m_remainder = i + d_mr; if(m_remainder >= 4) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 4) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM @@ -22948,18 +6597,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha @@ -22981,7 +6619,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB if(3 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 3) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 3) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM @@ -22990,25 +6628,9 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double 
const*)(b11)); - ymm0 = _mm256_broadcast_sd((double const *)(b11+ 2)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 @@ -23018,16 +6640,14 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm0 = _mm256_loadu_pd((double const *)b11); ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); - _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); + BLIS_POST_DTRSM_SMALL_1N_3M(b11,cs_b) m_remainder -=3; } else if(2 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM b10 = B + (m_remainder - 2) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM b11 = B + (m_remainder - 2) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM @@ -23036,76 +6656,41 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - xmm5 = _mm_loadu_pd((double const*)(b11)); - ymm6 = _mm256_insertf128_pd(ymm0, xmm5, 0); - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_pd((double const *)b11); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x03); - - xmm5 = _mm256_extractf128_pd(ymm3, 0); - _mm_storeu_pd((double *)(b11), xmm5); + BLIS_POST_DTRSM_SMALL_1N_2M(b11,cs_b) m_remainder -=2; } else if (1 == m_remainder) { a01 = D_A_pack; - a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1); //pointer to block of A to be used for TRSM - b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM + a11 = L + (n_remainder - 1)*cs_a + (n_remainder - 1)*rs_a; //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 1) + (n_remainder)*cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 1) + (n_remainder - 1)*cs_b; //pointer to block of B to be used for TRSM k_iter = (n-n_remainder); //number of GEMM operations to be done(in blocks of 4x4) ymm3 = _mm256_setzero_pd(); ///GEMM implementation starts/// - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - //load 8x1 block of B10 - ymm0 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) - //broadcast 1st row of A01 - ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); //A01[0][0] - ymm3 = 
_mm256_fmadd_pd(ymm2, ymm0, ymm3); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - - a01 += 1; //move to next row - b10 += cs_b; - } - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm6 = _mm256_broadcast_sd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_fmsub_pd(ymm6, ymm15, ymm3); //B11[0-3][0] * alpha -= ymm0 + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal,b11,cs_b) ///implement TRSM/// //extract a00 ymm0 = _mm256_broadcast_sd((double const *)(d11_pack )); ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x01); - - _mm_storel_pd(b11, _mm256_extractf128_pd(ymm3, 0)); + BLIS_POST_DTRSM_SMALL_1N_1M(b11,cs_b) m_remainder -=1; } @@ -23120,4 +6705,4019 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB } return BLIS_SUCCESS; } -#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM + +/* TRSM for the case AX = alpha * B, Double precision + * A is lower-triangular, transpose, non-unit diagonal + * dimensions A: mxm X: mxn B: mxn +*/ +BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. + if(transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; //loop variables + dim_t k_iter; //number of times GEMM to be performed + + double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha + double *L = a->buffer; //pointer to matrix A + double *B = b->buffer; //pointer to matrix B + + //pointers that point to blocks for GEMM and TRSM + double *a10, *a11, *b01, *b11; + + double ones = 1.0; + bool is_unitdiag = bli_obj_has_unit_diag(a); + + //scratch registers + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15; + __m256d ymm16, ymm17, ymm18, ymm19; + __m256d ymm20; + + __m128d xmm5; + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; + double d11_pack[d_mr] __attribute__((aligned(64))); + rntm_t rntm; + + bli_rntm_init_from_global( &rntm ); + bli_rntm_set_num_threads_only( 1, &rntm ); + bli_membrk_rntm_set_membrk( &rntm ); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if((d_mr * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if(required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if(NULL==D_A_pack) return BLIS_NULL_POINTER; + } + + /* + Performs solving TRSM for 8 colmns at a time from 0 to m/d_mr in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-d_mr) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. 
Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for(i = (m - d_mr); (i + 1) > 0; i -= d_mr) + { + a10 = L + (i*cs_a) + (i + d_mr)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + //ptr_a10_dup = D_A_pack; + + dim_t p_lda = d_mr; // packed leading dimension + + if(transa) + { + /* + Load, transpose and pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in the GEMM portion only and this + a10 block size will be increasing by d_mr for every next iteration + until it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_dtrsm_small_pack('L', (m-i-d_mr), 1, a10, cs_a, D_A_pack,p_lda,d_mr); + + /* + Pack 8 diagonal elements of A block into an array + a. This helps utilize the cache line efficiently in the TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr); + } + else + { + bli_dtrsm_small_pack('L', (m-i-d_mr), 0, a10, rs_a, D_A_pack,p_lda,d_mr); + dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr); + } + + /* + a. Perform GEMM using a10, b01. + b. Perform TRSM on a11, b11 + c. This GEMM+TRSM loop operates with an 8x6 block size + along the n dimension for every d_nr rows of b01 where + the packed A buffer is reused in computing all n rows of B. + d. The same approach is used in the remaining fringe cases. + */ + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) + { + a10 = D_A_pack; + b01 = B + (j * cs_b) + i + d_mr; //pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + /* + Perform GEMM between a10 and b01 blocks + For the first iteration there will be no GEMM operation + where k_iter is zero + */ + BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter) + + /* + Load b11 of size 6x8 and multiply with alpha + Add the GEMM output and perform an in-register transpose of b11 + to perform the TRSM operation. + */ + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal) + + /* + Compute 8x6 TRSM block by using GEMM block output in registers + a. The 8x6 input (gemm outputs) are stored in combinations of ymm registers + 1. ymm15, ymm20 2. ymm14, ymm19 3. ymm13, ymm18 , 4. ymm12, ymm17 + 5. ymm11, ymm7 6. ymm10, ymm6, 7.ymm9, ymm5 8. ymm8, ymm4 + where ymm15-ymm8 hold the 8x4 data and the remaining 8x2 are held by + other registers + b. 
Towards the end do in regiser transpose of TRSM output and store in b11 + */ + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm20, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm15, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm20, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm15, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm20, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm15, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm20, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm15, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm20, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm15, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm20, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm15, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm20, ymm4); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm14, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm19, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm14, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm19, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm14, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm19, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm14, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm19, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm14, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm19, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm14, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm19, ymm4); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm13, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm18, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm13, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm18, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm13, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm18, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm9 = 
_mm256_fnmadd_pd(ymm2, ymm13, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm18, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm13, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm18, ymm4); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm12, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm17, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm12, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm17, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm12, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm17, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm12, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm17, ymm4); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) + } + + dim_t n_remainder = j + d_nr; + if(n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + ((n_remainder - 4)* cs_b) + i + d_mr; + b11 = B + ((n_remainder - 4)* cs_b) + i; + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm7 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); 
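/*
 Illustrative scalar sketch (editor's addition, not part of the kernel): the
 broadcast + _mm256_fnmadd_pd + DTRSM_SMALL_DIV_OR_SCALE chain above is a
 vectorized back-substitution over the nb x nb block a11. For a single column
 x of the tile (already alpha-scaled and reduced by the GEMM output), the
 equivalent scalar recurrence is sketched below. The function name is
 hypothetical; d11_pack is assumed to hold the diagonal of a11 or its
 reciprocal, depending on whether preinversion is enabled.
*/
static void dtrsm_small_backsub_ref_sketch
     (
       const double *a11,      /* nb x nb triangular block of A               */
       const double *d11_pack, /* packed diagonal (or its reciprocal) of a11  */
       double       *x,        /* one column of the b11 tile, length nb       */
       dim_t nb, dim_t rs_a, dim_t cs_a
     )
{
    for(dim_t k = nb; k-- > 0; )
    {
        /* DTRSM_SMALL_DIV_OR_SCALE: divide, or multiply by the reciprocal */
        x[k] = x[k] / d11_pack[k];

        /* eliminate x[k] from the earlier rows; a11[k*rs_a + r*cs_a] is the
           same element the kernel broadcasts before each _mm256_fnmadd_pd */
        for(dim_t r = 0; r < k; r++)
            x[r] -= a11[k*rs_a + r*cs_a] * x[k];
    }
}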
+ ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + n_remainder -=4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + d_mr; + b11 = B + i; + + k_iter = (m - i - d_mr) ; + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = 
_mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = 
_mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6*cs_a + 7*rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 7*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 7*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 7*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 7*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7*rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5*cs_a + 6*rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 6*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 6*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 6*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6*rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4*cs_a + 5*rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 5*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 5*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5*rs_a)); + + //(ROw5): FMA operations + ymm12 = 
_mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3*cs_a + 4*rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 4*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4*rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = 
_mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + }// End of multiples of d_mr blocks in m-dimension + + // Repetative A blocks will be 4*4 + dim_t m_remainder = i + d_mr; + if(m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i*cs_a) + (i + 4)*rs_a; //pointer to block of A to be used for GEMM + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-i+4;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-i-4;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = 
_mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b) + i; //pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, 
ymm7, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm7, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm7, ymm4); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm6, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm6, ymm4); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + ymm4 = _mm256_fnmadd_pd(ymm2, ymm5, ymm4); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + i + 4; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b) + i; //pointer to 
block of B to be used for TRSM + + k_iter = (m - i - 4); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = 
_mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i*cs_a) + (i*rs_a); + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if(3 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] 
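+            // (the unpacklo/unpackhi pairs above and the permute2f128 pairs below
+            //  together form a 4x4 in-register transpose of the B11 tile, so that
+            //  each ymm register holds one row of B11 for the row-wise substitution
+            //  that follows; the matching unpack/permute sequence after the solve
+            //  transposes the result back before it is stored)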
+ + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2*cs_a + 3*rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3*rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2*rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2*rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1*rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_remainder -= 4; + } + + a10 = L + m_remainder*rs_a; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = 
_mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + 
BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_remainder) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = 
_mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double 
const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + + } + else if(1 == m_remainder) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x =0;x < m-m_remainder;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < m-m_remainder;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x*rs_a)); + 
_mm256_storeu_pd((double *)(ptr_a10_dup + x*p_lda), ymm0); + } + } + //cols + for(j = (n - d_nr); (j + 1) > 0; j -= d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + (j* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ///GEMM code ends/// + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to store alpha value + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 6, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 6, rs_a, cs_b, is_unitdiag); + } + dim_t n_remainder = j + d_nr; + if((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4)* cs_b) + m_remainder; //pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4)* cs_b); //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); 
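+            // blend mask 0x0E: lane 0 takes the freshly computed value while lanes
+            // 1-3 keep the values loaded from B11, since m_remainder == 1 means only
+            // one row is valid; the full 4-double stores below then write the
+            // untouched rows back unchanged (the masks 0x0C and 0x08 used in the
+            // 2-row and 3-row fringe paths follow the same idea)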
+ + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + if(n_remainder) + { + a10 = D_A_pack; + a11 = L; //pointer to block of A to be used for TRSM + b01 = B + m_remainder; //pointer to block of B to be used for GEMM + b11 = B; //pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_remainder) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm,&local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + +/* TRSM for the Left Upper case AX = alpha * B, Double precision + * A is Left side, upper-triangular, transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + a10 ----> b11---> + *********** ***************** + * * * * *b01*b11* * * + **a10 * * a11 b11 * * * * * + ********* | | ***************** + *a11* * | | * * * * * + * * * | | * * * * * + ****** v v ***************** + * * * * * * * + * * * * * * * + * * ***************** + * + a11---> + + * TRSM for the case AX = alpha * B, Double precision + * A is Left side, lower-triangular, no-transpose, non-unit/unit diagonal + * dimensions A: mxm X: mxn B: mxn + + b01---> + * ***************** + ** * * * * * + * * * * * * * + * * *b01* * * * + * * * * * * * +a10 ****** b11 ***************** + | * * * | * * * * * + | * * * | * * * * * + | *a10*a11* | *b11* * * * + v * * * v * * * * * + *********** ***************** + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + * * * * * * * * * + **************** ***************** + a11---> +*/ +BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB +( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl +) +{ + dim_t m = bli_obj_length(b); // number of rows of matrix B + dim_t n = bli_obj_width(b); // number of columns of matrix B + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8,d_nr = 6; + + // Swap rs_a & cs_a in case of non-tranpose. 
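+    // With this swap both variants handled by this kernel share one addressing
+    // scheme: element (i,k), i > k, of the effective lower-triangular operand is
+    // always read as a[i*cs_a + k*rs_a], whether A is upper-triangular and
+    // transposed (A'X = B) or lower-triangular and non-transposed (AX = B).
+    //
+    // Illustrative scalar sketch only (not part of the kernel; a and b here are
+    // generic pointers, not the a10/a11/b01/b11 blocks used below): a forward
+    // substitution written against the swapped strides covers both cases for one
+    // right-hand-side column b of length m:
+    //
+    //     for (dim_t k = 0; k < m; k++)
+    //     {
+    //         if (!is_unitdiag) b[k] /= a[k*cs_a + k*rs_a];
+    //         for (dim_t i = k + 1; i < m; i++)
+    //             b[i] -= a[i*cs_a + k*rs_a] * b[k];
+    //     }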
+    if(transa)
+    {
+        cs_a = bli_obj_col_stride(a); // column stride of A
+        rs_a = bli_obj_row_stride(a); // row stride of A
+    }
+    else
+    {
+        cs_a = bli_obj_row_stride(a); // row stride of A
+        rs_a = bli_obj_col_stride(a); // column stride of A
+    }
+    dim_t cs_b = bli_obj_col_stride(b); // column stride of B
+
+    dim_t i, j, k;    //loop variables
+    dim_t k_iter;     //number of times GEMM to be performed
+
+    double AlphaVal = *(double *)AlphaObj->buffer;    //value of alpha
+    double *L = a->buffer;    //pointer to matrix A
+    double *B = b->buffer;    //pointer to matrix B
+
+    double *a10, *a11, *b01, *b11;    //pointers that point to blocks for GEMM and TRSM
+
+    double ones = 1.0;
+    bool is_unitdiag = bli_obj_has_unit_diag(a);
+
+    //scratch registers
+    __m256d ymm0, ymm1, ymm2, ymm3;
+    __m256d ymm4, ymm5, ymm6, ymm7;
+    __m256d ymm8, ymm9, ymm10, ymm11;
+    __m256d ymm12, ymm13, ymm14, ymm15;
+    __m256d ymm16, ymm17, ymm18, ymm19;
+    __m256d ymm20;
+
+    __m128d xmm5;
+
+    gint_t required_packing_A = 1;
+    mem_t local_mem_buf_A_s = {0};
+    double *D_A_pack = NULL;
+    double d11_pack[d_mr] __attribute__((aligned(64)));
+    rntm_t rntm;
+
+    bli_rntm_init_from_global( &rntm );
+    bli_rntm_set_num_threads_only( 1, &rntm );
+    bli_membrk_rntm_set_membrk( &rntm );
+
+    siz_t buffer_size = bli_pool_block_size(
+                          bli_membrk_pool(
+                            bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+                            bli_rntm_membrk(&rntm)));
+
+    if ( (d_mr * m * sizeof(double)) > buffer_size)
+        return BLIS_NOT_YET_IMPLEMENTED;
+
+    if (required_packing_A == 1)
+    {
+        // Get the buffer from the pool.
+        bli_membrk_acquire_m(&rntm,
+                             buffer_size,
+                             BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                             &local_mem_buf_A_s);
+        if(FALSE==bli_mem_is_alloc(&local_mem_buf_A_s)) return BLIS_NULL_POINTER;
+        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+        if(NULL==D_A_pack) return BLIS_NULL_POINTER;
+    }
+
+    /*
+        Solves TRSM for d_mr (8) rows of B at a time, iterating i from 0 to m in steps of d_mr
+        a. Load, transpose and pack A (the a10 block); the packing size ranges from 8x6 to 8x(m-8)
+           For the first iteration there is no GEMM and no packing of a10, since only TRSM is performed
+        b. Perform the GEMM operation using the packed a10 block and the b01 block
+        c. Using the GEMM output, perform the TRSM operation with a11 and b11 and update B
+        d. Repeat b and c for the n columns of B in steps of d_nr
+    */
+    for(i = 0;(i+d_mr-1) < m; i += d_mr)    //loop along 'M' dimension
+    {
+        a10 = L + (i*cs_a);    //pointer to block of A to be used for GEMM
+        a11 = L + (i*rs_a) + (i*cs_a);
+        dim_t p_lda = d_mr; // packed leading dimension
+
+        if(transa)
+        {
+            /*
+              Load, transpose and pack the current A block (a10) into the packed buffer D_A_pack
+              a. This a10 block is used only in the GEMM portion; its size grows by d_mr on every
+                 iteration until it reaches 8x(m-8), the maximum GEMM-only block size in A
+              b. This packed buffer is reused while computing all n columns of the B matrix
+            */
+            bli_dtrsm_small_pack('L', i, 1, a10, cs_a, D_A_pack, p_lda,d_mr);
+
+            /*
+               Pack the 8 diagonal elements of the A block into an array
+               a. This helps utilize cache lines efficiently in the TRSM operation
+               b. Ones are stored when the input is unit-diagonal
+            */
+            dtrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,d_mr);
+        }
+        else
+        {
+            bli_dtrsm_small_pack('L', i, 0, a10, rs_a, D_A_pack, p_lda,d_mr);
+            dtrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,d_mr);
+        }
+
+        /*
+            a. Perform GEMM using a10, b01.
+            b. Perform TRSM on a11, b11.
+            c. This GEMM+TRSM loop operates on an 8x6 block size along the n dimension,
+               for every d_nr columns of b01, and the packed A buffer is reused in
+               computing all n columns of B.
+            d. The same approach is used in the remaining fringe cases.
+        */
+        dim_t temp = n - d_nr + 1;
+        for(j = 0; j < temp; j += d_nr)    //loop along 'N' dimension
+        {
+            a10 = D_A_pack;
+            a11 = L + (i*rs_a) + (i*cs_a);    //pointer to block of A to be used for TRSM
+            b01 = B + j*cs_b;    //pointer to block of B to be used for GEMM
+            b11 = B + i + j* cs_b;    //pointer to block of B to be used for TRSM
+
+            k_iter = i;
+
+            /*Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS
+
+            /*
+                Perform GEMM between the a10 and b01 blocks.
+                For the first iteration there is no GEMM operation,
+                since k_iter is zero.
+            */
+            BLIS_DTRSM_SMALL_GEMM_8mx6n(a10,b01,cs_b,p_lda,k_iter)
+
+            /*
+                Load b11 of size 6x8, multiply it with alpha,
+                add the GEMM output and perform an in-register transpose of b11
+                to perform the TRSM operation.
+            */
+            BLIS_DTRSM_SMALL_NREG_TRANSPOSE_6x8(b11,cs_b,AlphaVal)
+
+            /*
+                Compute the 8x6 TRSM block using the GEMM block output held in registers
+                a. The 8x6 inputs (GEMM outputs) are stored in the register pairs
+                   1. ymm8, ymm4   2. ymm9, ymm5   3. ymm10, ymm6   4. ymm11, ymm7
+                   5. ymm12, ymm17 6. ymm13, ymm18 7. ymm14, ymm19  8. ymm15, ymm20
+                   where ymm8-ymm15 hold the 8x4 data and the remaining 8x2 is held by
+                   the other registers
+                b. Towards the end, do an in-register transpose of the TRSM output and store it in b11
+            */
+            ////extract a00
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+            //perform mul operation
+            ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1);
+            ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1);
+
+            //extract a11
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+            //(ROw1): FMA operations
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1));
+            ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9);
+            ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2));
+            ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10);
+            ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3));
+            ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11);
+            ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4));
+            ymm12 = _mm256_fnmadd_pd(ymm2, ymm8, ymm12);
+            ymm17 = _mm256_fnmadd_pd(ymm2, ymm4, ymm17);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5));
+            ymm13 = _mm256_fnmadd_pd(ymm2, ymm8, ymm13);
+            ymm18 = _mm256_fnmadd_pd(ymm2, ymm4, ymm18);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6));
+            ymm14 = _mm256_fnmadd_pd(ymm2, ymm8, ymm14);
+            ymm19 = _mm256_fnmadd_pd(ymm2, ymm4, ymm19);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7));
+            ymm15 = _mm256_fnmadd_pd(ymm2, ymm8, ymm15);
+            ymm20 = _mm256_fnmadd_pd(ymm2, ymm4, ymm20);
+
+
+            //perform mul operation
+            ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1);
+            ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1);
+
+            a11 += rs_a;
+
+            //extract a22
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+            //(ROw2): FMA operations
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2));
+            ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10);
+            ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3));
+            ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11);
+            ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7);
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4));
+            ymm12 = _mm256_fnmadd_pd(ymm2, ymm9, ymm12);
+            ymm17 
= _mm256_fnmadd_pd(ymm2, ymm5, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm9, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm5, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm9, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm5, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm9, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm5, ymm20); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm10, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm6, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm10, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm6, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm10, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm6, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm10, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm6, ymm20); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + a11 += rs_a; + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + //(ROw4): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm17 = _mm256_fnmadd_pd(ymm2, ymm7, ymm17); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm11, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm7, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm11, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm7, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm11, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm7, ymm20); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm17 = DTRSM_SMALL_DIV_OR_SCALE(ymm17, ymm1); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm18 = _mm256_fnmadd_pd(ymm2, ymm17, ymm18); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm12, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm17, ymm19); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm17, ymm20); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + ymm18 = DTRSM_SMALL_DIV_OR_SCALE(ymm18, ymm1); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm19 = _mm256_fnmadd_pd(ymm2, ymm18, ymm19); + 
ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm13, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm18, ymm20); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + ymm19 = DTRSM_SMALL_DIV_OR_SCALE(ymm19, ymm1); + + a11 += rs_a; + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + //(ROw7): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm20 = _mm256_fnmadd_pd(ymm2, ymm19, ymm20); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + ymm20 = DTRSM_SMALL_DIV_OR_SCALE(ymm20, ymm1); + + a11 += rs_a; + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x6_AND_STORE(b11,cs_b) + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i ; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *3 + 4)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] + + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); 
//B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, 
ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + 
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); //store B11[7][0-3] + + n_rem -=4; + j +=4; + } + + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *2 + 4)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *1 + 4)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = 
_mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *0 + 4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + ///implement TRSM/// + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, 
ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + cs_a*4)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + //perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + cs_a*5)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + + //extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + //perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a*6)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 +cs_a*7)); + + a11 += rs_a; + + //extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + //perform mul operation + ymm14 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + //extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a*7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + //perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); //store B11[6][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); //store B11[5][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); //store B11[4][0-3] + } + } + } + + //======================M remainder cases================================ + dim_t m_rem = m-i; + if(m_rem>=4) //implementation for reamainder rows(when 'M' is not a multiple of d_mr) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + + if(transa) + { + for(dim_t x =0;x < i;x+=p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 
= _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda*p_lda; + } + } + else + { + for(dim_t x =0;x < i;x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if(!is_unitdiag) + { + if(transa) + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+cs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); + } + else + { + //broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11+rs_a*1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11+rs_a*2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11+rs_a*3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; + #endif + #ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); + #endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for(j = 0; (j+d_nr-1) < n; j += d_nr) //loop along 'N' dimension + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM operation to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx6n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = 
_mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + + ymm16 = _mm256_broadcast_sd((double const *)(&ones)); + + ////unpacklow//// + ymm7 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + //ymm16; + + //rearrange low elements + ymm4 = _mm256_permute2f128_pd(ymm7,ymm16,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm6 = _mm256_permute2f128_pd(ymm7,ymm16,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + //ymm16; + + //rearrange high elements + ymm5 = _mm256_permute2f128_pd(ymm0,ymm16,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm7 = _mm256_permute2f128_pd(ymm0,ymm16,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + //b11 transpose end + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw1): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm5 = _mm256_fnmadd_pd(ymm2, ymm4, ymm5); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm8, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm4, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm4, ymm7); + + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm1); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm6 = _mm256_fnmadd_pd(ymm2, ymm5, ymm6); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm9, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm5, ymm7); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm1); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm7 = _mm256_fnmadd_pd(ymm2, ymm6, ymm7); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm1); + + a11 += rs_a; + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); 
//B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///unpack high/// + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store B11[1][0-3] + } + + dim_t n_rem = n-j; + if(n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); 
//B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + //extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if(n_rem) + { + a10 = D_A_pack; + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + j*cs_b; //pointer to block of B to be used for GEMM + b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] 
B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + ///transpose of B11// + ///unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + //rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + //rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + //extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a*1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + + //perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a*2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + //perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = 
_mm256_broadcast_sd((double const *)(a11 + cs_a*3)); + + a11 += rs_a; + + //extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + //perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + //unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + //rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + //rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if(3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] + } + else if(2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] + } + else if(1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] + } + } + m_rem -=4; + i +=4; + } + + if(m_rem) + { + a10 = L + (i*cs_a); //pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if(3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b*3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = 
_mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3),xmm5); + _mm_storel_pd((b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm3, 1)); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b,is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if(2 == m_rem) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0C); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0C); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0C); + + _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) + xmm5 = _mm256_extractf128_pd(ymm3, 0); + _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + 
dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j +=4; + } + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -=2; + i+=2; + } + else if(1 == m_rem) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if(transa) + { + for(dim_t x=0;x= 4)) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10,b01,cs_b,p_lda,k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); //register to hold alpha + + ///implement TRSM/// + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b *0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b *1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b *2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b *3)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); + ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); + ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); + ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); + + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm0, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm1, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm2, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm3, 0)); + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j+=4; + } + + if(n_rem) + { + a10 = D_A_pack; //pointer to block of A to be used for GEMM + a11 = L + (i*rs_a) + (i*cs_a); //pointer to block of A to be used for TRSM + b01 = B + (j*cs_b); //pointer to block of 
B to be used for GEMM + b11 = B + i + (j* cs_b); //pointer to block of B to be used for TRSM + + k_iter = i; //number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS + + if(3 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if(2 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if(1 == n_rem) + { + ///GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,p_lda,k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal,b11,cs_b) + + if(transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -=1; + i+=1; + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc( &local_mem_buf_A_s )) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} +#endif //BLIS_ENABLE_SMALL_MATRIX_TRSM \ No newline at end of file