From c63a078a575a15f3bd1b74d62bd3f31950f8a786 Mon Sep 17 00:00:00 2001 From: Meghana Date: Thu, 21 Nov 2019 12:31:09 +0530 Subject: [PATCH] Fixed segemntation fault in trsm_small kernels for cases XAuB, XAltB, XAlB For matrix sizes which are not multiples of 4, trsm_small kernels access memory outside the allocated buffers which causes segmentation fault. This is fixed by handling each of the corner cases separately. Change-Id: I267e69ee095a8ca3e8ce2a3ada5f48bfefcc2219 --- kernels/zen/3/bli_trsm_small.c | 3530 ++++++++++++++++++++------------ 1 file changed, 2266 insertions(+), 1264 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index af84d0588..ee4c07a49 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -931,7 +931,6 @@ a10 ****** b11 ***************** **************** ***************** a11---> */ - static err_t bli_dtrsm_small_AlXB( side_t side, obj_t* AlphaObj, @@ -4161,101 +4160,6 @@ static err_t bli_dtrsm_small_XAuB( ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -4263,137 +4167,170 @@ static err_t bli_dtrsm_small_XAuB( //load 8x4 block of B11 if(n_remainder == 3) { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] + //3rd col + a11 += cs_a; + ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] + //4th col + a11 += cs_a; + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - if(n_remainder == 3) - { _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) @@ -4402,17 +4339,223 @@ static err_t bli_dtrsm_small_XAuB( _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) } if(n_remainder == 2) - { + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1 + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) } + if(n_remainder == 1) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + + a01 += 1; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ///GEMM code ends/// + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + } } } if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) @@ -4608,205 +4751,311 @@ static err_t bli_dtrsm_small_XAuB( k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - - ///GEMM processing stars/// - - for(k = 0; k < k_iter; k++) + if(n_remainder == 3) { - ptr_a01_dup = a01; + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + ///GEMM for previous blocks /// + ///GEMM processing stars/// - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; - a01 += 1; //move to next row of A + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + a01 += 1; //move to next row of A - a01 += 1; //move to next row of A + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + a01 += 1; //move to next row of A - a01 += 1; //move to next row of A + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + a01 += 1; //move to next row of A - a01 += 1; //move to next row of A + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + + } + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1] + + //3rd col + a11 += cs_a; + ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + ymm14 = _mm256_broadcast_sd((double const *)&ones); - } + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm6 = _mm256_unpacklo_pd(ymm9, ymm14); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - ///GEMM code ends/// + ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - ///implement TRSM/// + ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] - ///read 4x4 block of A11/// + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //(Row1): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0] + ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0] - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1] + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - //2nd col - a11 += cs_a; - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1] + //(Row2)FMA operations + ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1] - //3rd col - a11 += cs_a; - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] + ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2] - //4th col - a11 += cs_a; - ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] - - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0] - ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0] - ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0] - - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1] - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row2)FMA operations - ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1] - ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1] - - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2] - - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3] - - if(n_remainder == 3) - { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) } - if(n_remainder == 2) + if(n_remainder == 2) { + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM for previous blocks /// + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + + } + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += cs_a; + ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] + ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm15 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0] + + ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1] + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) } if(n_remainder == 1) { + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM for previous blocks /// + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st col + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm14 = _mm256_div_pd(ymm14, ymm4); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) } - } m_remainder -= 4; i += 4; @@ -4821,20 +5070,21 @@ static err_t bli_dtrsm_small_XAuB( b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4) - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b_offset[1])[iter]; + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ///GEMM for previous blocks /// ///load 4x4 block of b11 ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha + ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); @@ -4912,10 +5162,10 @@ static err_t bli_dtrsm_small_XAuB( ///GEMM implementation ends/// - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[x][0] -=ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 ///implement TRSM/// @@ -4990,34 +5240,33 @@ static err_t bli_dtrsm_small_XAuB( ymm4 = _mm256_loadu_pd((double const *)(b11)); ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); if(m_remainder == 3) { ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); } if(m_remainder == 2) { ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30); ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30); ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30); } if(m_remainder == 1) { ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E); ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E); ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E); - ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3]) + + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b_offset[1])[iter] = f_temp[iter]; } if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR) { @@ -5027,149 +5276,263 @@ static err_t bli_dtrsm_small_XAuB( b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4) + + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter]; + ///GEMM for previous blocks /// + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); ///load 4x4 block of b11 if(n_remainder == 3) { ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ///GEMM implementation starts/// + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ///GEMM implementation ends + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha + + ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); + ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); + + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); + ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30); + ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); + ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E); + ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)f_temp, ymm2); //(store(B11[x][2])) } if(n_remainder == 2) { ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM implementation starts/// + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + + a01 += 1; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ///GEMM implementation ends + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha + + ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); + ///implement TRSM/// + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); + ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); + ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1]) } if(n_remainder == 1) { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); + ymm0 = _mm256_loadu_pd((double const *)f_temp); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; + ///GEMM implementation starts/// + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - a01 += 1; //move to next row of A + a01 += 1; //move to next row of A - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + ///GEMM implementation ends - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha + ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); + ///implement TRSM/// + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); + } + _mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0]) } - ///GEMM implementation ends - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha - - ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); - ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); - ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); - ymm7 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); - ///implement TRSM/// - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm7, ymm3, 0x08); - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); - ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30); - ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30); - ymm3 = _mm256_permute2f128_pd(ymm7,ymm3,0x30); - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); - ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E); - ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E); - ymm3 = _mm256_blend_pd(ymm7,ymm3,0x0E); - } - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - } + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; //scalar code for TRSM dtrsm_small_XAuB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } @@ -6438,7 +6801,7 @@ static err_t bli_dtrsm_small_XAltB( for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction + for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction { a01 = L + j; //pointer to block of A to be used in GEMM a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM @@ -6691,239 +7054,180 @@ static err_t bli_dtrsm_small_XAltB( ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //subtract the calculated GEMM block from current TRSM block //load 8x4 block of B11 if(n_remainder == 3) { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + //3rd col + a11 += 1; + ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] + //(Row2)FMA operations + ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] + ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] + ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] + ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - if(n_remainder == 3) - { _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) @@ -6933,15 +7237,216 @@ static err_t bli_dtrsm_small_XAltB( } if(n_remainder == 2) { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm8 = _mm256_loadu_pd((double const *)b11); + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); + ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + //2nd col + a11 += 1; + ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + //extract a00 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] + + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] + + ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] + + ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] + + ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } + } if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + + a01 += cs_a; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) + + //broadcast 3rd row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + + a01 += cs_a; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) + + //broadcast 4th row of A01 + ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + + a01 += cs_a; //move to next row of A + + ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) + + ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + ///GEMM code ends/// + + ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] + ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] + + ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] + ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] + + ///implement TRSM/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm8 = _mm256_mul_pd(ymm8, ymm7); //B11[0-3][0] /= A11[0][0] + + ymm12 = _mm256_mul_pd(ymm12, ymm7); //B11[4-7][0] /= A11[0][0] + + _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) } } } @@ -7141,35 +7646,6 @@ static err_t bli_dtrsm_small_XAltB( k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); @@ -7177,169 +7653,313 @@ static err_t bli_dtrsm_small_XAltB( ymm7 = _mm256_setzero_pd(); - ///GEMM processing stars/// - - for(k = 0; k < k_iter; k++) + if(n_remainder == 3) { - ptr_a01_dup = a01; + ///GEMM for previous blocks /// + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + ///GEMM processing stars/// - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; - a01 += cs_a; //move to next row of A + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] + a01 += cs_a; //move to next row of A - a01 += cs_a; //move to next row of A + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + a01 += cs_a; //move to next row of A - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - } + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } - ///GEMM code ends/// + ///GEMM code ends/// - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[x][0] -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[x][2] -= ymm6 - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - //1st row - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); - ymm7 = _mm256_broadcast_sd((double const *)(a11+2)); - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); + //1st row + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); + ymm7 = _mm256_broadcast_sd((double const *)(a11+2)); - a11 += cs_a;//move to next column + a11 += cs_a;//move to next column - //2nd row - ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); - ymm8 = _mm256_broadcast_sd((double const *)(a11+2)); - ymm11 = _mm256_broadcast_sd((double const *)(a11+3)); + //2nd row + ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); + ymm8 = _mm256_broadcast_sd((double const *)(a11+2)); - a11 += cs_a;//move to next column + a11 += cs_a;//move to next column - //3rd row - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); + //3rd row + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); - a11 += cs_a;//move to next column + a11 += cs_a;//move to next column - //4th row - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); + //4th row + ymm13 = _mm256_broadcast_sd((double const *)(&ones)); + ymm14 = _mm256_broadcast_sd((double const *)&ones); - ymm14 = _mm256_broadcast_sd((double const *)&ones); + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //(Row1): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0] + ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0] - //(Row1): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0] - ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0] - ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0] + ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1] + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1] - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //(Row2)FMA operations + ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1] - //(Row2)FMA operations - ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1] - ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1] + ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2] - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2] - - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) } - if(n_remainder == 2) + if(n_remainder == 2) { + ///GEMM for previous blocks /// + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[x][0] -= ymm4 + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[x][1] -= ymm5 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st row + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); + ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); + + a11 += cs_a;//move to next column + + //2nd row + ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); + + a11 += cs_a;//move to next column + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + + ymm15 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract A00 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) + + ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0] + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(Row1): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0] + + ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1] + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) } if(n_remainder == 1) { + ///GEMM for previous blocks /// + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[x][0] -= ymm4 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //1st row + ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + ymm14 = _mm256_div_pd(ymm14, ymm4); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + ymm0 = _mm256_mul_pd(ymm0, ymm14); //B11[x][0] /= A11[0][0] + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) } @@ -7357,14 +7977,21 @@ static err_t bli_dtrsm_small_XAltB( b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4) - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha + + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b_offset[1])[iter]; + + ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha ///GEMM for previous blocks /// ///load 4x4 block of b11 ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] //multiply by alpha ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha @@ -7528,34 +8155,33 @@ static err_t bli_dtrsm_small_XAltB( ymm4 = _mm256_loadu_pd((double const *)(b11)); ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); if(m_remainder == 3) { ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); } if(m_remainder == 2) { ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30); ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30); ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30); } if(m_remainder == 1) { ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E); ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E); ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E); - ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E); } _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) + _mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3]) + + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b_offset[1])[iter] = f_temp[iter]; } if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR) { @@ -7565,34 +8191,119 @@ static err_t bli_dtrsm_small_XAltB( b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4) - ///GEMM for previous blocks /// - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - } + double f_temp[4]; + int iter; + + for(iter = 0; iter < m_remainder; iter++) + f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter]; + ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); + ///GEMM for previous blocks /// + + if(n_remainder == 3) + { + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + ///GEMM implementation ends + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha + + ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); + ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); + ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); + ///implement TRSM/// + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); + ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08); + ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); + ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30); + ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); + ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E); + ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E); + } + _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) + _mm256_storeu_pd((double *)(f_temp), ymm2); //(store(B11[x][2])) + } + if(n_remainder == 2) + { + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] ///GEMM implementation starts/// @@ -7609,54 +8320,38 @@ static err_t bli_dtrsm_small_XAltB( //broadcast 1st row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) //broadcast 2nd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) //braodcast 3rd row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) //broadcast 4th row of A01 ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] a01 += cs_a; //move to next row of A ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM @@ -7669,47 +8364,97 @@ static err_t bli_dtrsm_small_XAltB( ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); - ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); - ymm7 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); ///implement TRSM/// if(m_remainder == 3) { ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm7, ymm3, 0x08); } if(m_remainder == 2) { ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30); - ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30); - ymm3 = _mm256_permute2f128_pd(ymm7,ymm3,0x30); } if(m_remainder == 1) { ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E); - ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E); - ymm3 = _mm256_blend_pd(ymm7,ymm3,0x0E); } - if(n_remainder == 3) - { _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) + _mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1]) + } + if(n_remainder == 1) + { + ///load 4x4 block of b11 + ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM implementation starts/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) + + //broadcast 2nd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) + + //braodcast 3rd row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) + + //broadcast 4th row of A01 + ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] + + a01 += cs_a; //move to next row of A + + ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM + } + ///GEMM implementation ends + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha + + ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); + ///implement TRSM/// + if(m_remainder == 3) + { + ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08); + } + if(m_remainder == 2) + { + ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30); + } + if(m_remainder == 1) + { + ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E); + } + _mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0]) } + for(iter = 0; iter < m_remainder; iter++) + (b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter]; //scalar code for TRSM - dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); + dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); } } return BLIS_SUCCESS; @@ -9220,101 +9965,6 @@ static err_t bli_dtrsm_small_XAlB( ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); @@ -9322,137 +9972,171 @@ static err_t bli_dtrsm_small_XAlB( //load 8x4 block of B11 if(n_remainder == 3) { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - ///implement TRSM/// + ///implement TRSM/// - ///read 4x4 block of A11/// + ///read 4x4 block of A11/// - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] + //1st col + ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0] - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] + //2nd col + a11 += 1; + ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + //3rd col + a11 += 1; + ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + //compute reciprocals of L(i,i) and broadcast in registers + ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm11 = _mm256_mul_pd(ymm11, ymm0); + ymm11 = _mm256_mul_pd(ymm11, ymm0); - ymm15 = _mm256_mul_pd(ymm15, ymm0); + ymm15 = _mm256_mul_pd(ymm15, ymm0); - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - //(row 3):FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm10 = _mm256_mul_pd(ymm10, ymm0); + ymm10 = _mm256_mul_pd(ymm10, ymm0); - ymm14 = _mm256_mul_pd(ymm14, ymm0); + ymm14 = _mm256_mul_pd(ymm14, ymm0); - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + //extract a11 + ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - //(Row 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); + //(Row 2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm9 = _mm256_mul_pd(ymm9, ymm0); + ymm9 = _mm256_mul_pd(ymm9, ymm0); - ymm13 = _mm256_mul_pd(ymm13, ymm0); + ymm13 = _mm256_mul_pd(ymm13, ymm0); - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - if(n_remainder == 3) - { _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) @@ -9461,16 +10145,223 @@ static err_t bli_dtrsm_small_XAlB( _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } if(n_remainder == 2) - { + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] + ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] + ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] + + ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + + ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //3rd col + a11 += 2; + ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] + + //4th col + a11 += 1; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] + + ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] + + //extract a33 + ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm0); + + ymm15 = _mm256_mul_pd(ymm15, ymm0); + + //extract a22 + ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(row 3):FMA operations + ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); + + ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); + + ymm10 = _mm256_mul_pd(ymm10, ymm0); + + ymm14 = _mm256_mul_pd(ymm14, ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } + } if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) + { + ///GEMM implementation begins/// + + for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations + { + ptr_a01_dup = a01; + + //broadcast 1st row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + //load 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) + + //broadcast 3rd row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + //load next 8x2 block of B10 + ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) + ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) + ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) + ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) + + ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) + + //broadcast 4th row of A01 + ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) + + ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] + ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] + + ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] + ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + + //4th col + a11 += 3; + ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] + + //compute reciprocals of L(i,i) and broadcast in registers + ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) + + ymm11 = _mm256_mul_pd(ymm11, ymm7); + + ymm15 = _mm256_mul_pd(ymm15, ymm7); + + _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) + _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) } } } @@ -9670,205 +10561,316 @@ static err_t bli_dtrsm_small_XAlB( k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha ymm4 = _mm256_setzero_pd(); ymm5 = _mm256_setzero_pd(); ymm6 = _mm256_setzero_pd(); ymm7 = _mm256_setzero_pd(); - - ///GEMM processing stars/// - - for(k = 0; k < k_iter; k++) + ///GEMM for previous blocks /// + if(n_remainder == 3) { - ptr_a01_dup = a01; + ///load 4x4 block of b11 + ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + ///GEMM processing stars/// - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; - a01 += 1; //move to next row of A + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + //broadcast 1st row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + a01 += 1; //move to next row of A - a01 += 1; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + //broadcast 2nd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + a01 += 1; //move to next row of A - a01 += 1; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + //braodcast 3rd row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + a01 += 1; //move to next row of A - a01 += 1; //move to next row of A + ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + //broadcast 4th row of A01 + ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[x][1] -= ymm5 + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + //2nd col + a11 += cs_a; + ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] + ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] + ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] + + //3rd col + a11 += cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + + //extract a11 + ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) + + //(ROW 2): FMA operations + ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); + + ymm1 = _mm256_mul_pd(ymm1, ymm15); - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - - } - - ///GEMM code ends/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm3 = _mm256_mul_pd(ymm3, ymm15); - - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - ymm2 = _mm256_mul_pd(ymm2, ymm15); - - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - ymm1 = _mm256_mul_pd(ymm1, ymm15); - - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - ymm0 = _mm256_mul_pd(ymm0, ymm15); - - if(n_remainder == 3) - { _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) + _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) } - if(n_remainder == 2) + if(n_remainder == 2) { + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + } + + ///GEMM code ends/// + + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[x][2] -= ymm6 + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //3rd col + a11 += 2 * cs_a; + ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] + ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] + + //4th col + a11 += cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] + + ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] + ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) + + ymm3 = _mm256_mul_pd(ymm3, ymm15); + + //extract a22 + ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) + ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) + + //(Row 3): FMA operations + ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); + + ymm2 = _mm256_mul_pd(ymm2, ymm15); + _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) } if(n_remainder == 1) { + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ///GEMM processing stars/// + + for(k = 0; k < k_iter; k++) + { + ptr_a01_dup = a01; + + //load 4x4 bblock of b10 + ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] + ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] + ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] + ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] + + //broadcast 1st row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) + + //broadcast 2nd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) + + //braodcast 3rd row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) + + //broadcast 4th row of A01 + ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] + + a01 += 1; //move to next row of A + + ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) + + b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM + a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM + + } + + ///GEMM code ends/// + + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[x][3] -= ymm7 + + ///implement TRSM/// + + ///read 4x4 block of A11/// + + //4th col + a11 += 3 * cs_a; + ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] + + ymm14 = _mm256_broadcast_sd((double const *)&ones); + + //compute reciprocals of A(i,i) and broadcast in registers + ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] + + //extract a33 + ymm3 = _mm256_mul_pd(ymm3, ymm14); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) } - } m_remainder -= 4; i -= 4;