From ec907c3f4b86fcf3f72985a27f3e3b65bc34eb89 Mon Sep 17 00:00:00 2001 From: Meghana Date: Mon, 27 May 2019 15:36:44 +0530 Subject: [PATCH] Defined small matrix thresholds for TRSM for various cases for NAPLES and ROME Updated copyright information for kernels/zen/bli_trsm_small.c file Removed separate kernels for zen2 architecture Instead added threshold conditions in zen kernels both for ROME and NAPLES Change-Id: Ifd715731741d649b6ad16b123a86dbd6665d97e5 --- kernels/zen2/3/bli_trsm_small.c | 25023 ------------------------------ 1 file changed, 25023 deletions(-) delete mode 100644 kernels/zen2/3/bli_trsm_small.c diff --git a/kernels/zen2/3/bli_trsm_small.c b/kernels/zen2/3/bli_trsm_small.c deleted file mode 100644 index 346df3eca..000000000 --- a/kernels/zen2/3/bli_trsm_small.c +++ /dev/null @@ -1,25023 +0,0 @@ -/* - -BLIS -An object-based framework for developing high-performance BLAS-like -libraries. - -Copyright (C) 2018, Advanced Micro Devices, Inc. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: -- Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. -- Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. -- Neither the name of The University of Texas at Austin nor the names -of its contributors may be used to endorse or promote products -derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*/ - -#include "blis.h" -#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM -#include "immintrin.h" -#define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm -#define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1' -#define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together. -#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle -#define BLI_AlXB_M_SP 16 -//#define BLI_AlXB_M_DP 16 -#define BLI_XAltB_N_SP 128 -#define BLI_AutXB_M_SP 64 -#define BLI_AutXB_N_SP 128 -#define max(a,b) a>b?a:b -#define min(a,b) a 0; k--) - { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = M-1; i+1 > 0; i--) - { - B[i+k*ldb] *= lkk_inv; - for(j = k-1; j+1 > 0; j--) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM scalar code for the case XA = alpha * B - * A is lower-triangular, unit-diagonal, no transpose - *Dimensions: X:mxn A:nxn B:mxn - */ -static err_t dtrsm_small_XAlB_unitDiag( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb -) -{ - - dim_t i, j, k; - - for(i = 0 ; i < M; i++) - for(j = 0; j < N; j++) - B[i+j*ldb] *= alpha; - - for(k = N-1; k+1 > 0; k--) - { - for(i = M-1; i+1 > 0; i--) - { - for(j = k-1; j+1 > 0; j--) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM scalar code for the case XA = alpha * B - *A is upper-triangular, non-unit-diagonal, A is transposed - * Dimensions: X:mxn A:nxn B:mxn - */ -static err_t dtrsm_small_XAutB ( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb -) -{ - - dim_t i, j, k; - - for(i = 0; i < M; i++) - for(j = 0; j < N; j++) - B[i+j*ldb] *=alpha; - - for(k = N-1; k+1 > 0; k--) - { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = M-1; i+1 > 0; i--) - { - B[i+k*ldb] *= lkk_inv; - for(j = k-1; j+1 > 0; j--) - { - B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM scalar code for the case XA = alpha * B - * A is upper-triangular, unit-diagonal, A has to be transposed - * Dimensions: X:mxn A:nxn B:mxn - */ -static err_t dtrsm_small_XAutB_unitDiag( - double *A, - double *B, - double alpha, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb -) -{ - - dim_t i, j, k; - - for(i = 0; i< M; i++) - for(j = 0; j< N; j++) - B[i+j*ldb] *= alpha; - - for(i = M-1; i+1 > 0; i--) - { - for(j = N-1; j+1 > 0; j--) - { - for(k = j-1; k+1 > 0; k--) - { - B[i+k*ldb] -= B[i+j*ldb] * A[k+j*lda]; - - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM scalar code for the case XA = alpha * B - * A is lower-triangular, non-unit-diagonal, A has to be transposed - * Dimensions: X:mxn A:nxn B:mxn - */ -static err_t dtrsm_small_XAltB ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb -) -{ - - dim_t i, j, k; - - for(k = 0; k < N; k++) - { - double lkk_inv = 1.0/A[k+k*lda]; - for(i = 0; i < M; i++) - { - B[i+k*ldb] *= lkk_inv; - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM scalar code for XA = alpha * B - * A is lower-triangular, unit-diagonal, A has to be transposed - * Dimensions: X:mxn A:nxn B:mxn - */ -static err_t dtrsm_small_XAltB_unitDiag( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb -) -{ - - dim_t i, j, k; - - for(k = 0; k < N; k++) - { - for(i = 0; i < M; i++) - { - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda]; - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM scalar code for the case XA = alpha * B - * A is upper-triangular, unit-diagonal, no transpose - * Dimensions: X:mxn A:nxn B:mxn - */ -static err_t dtrsm_small_XAuB_unitDiag ( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb -) -{ - - dim_t i, j, k; - - for(k = 0; k < N; k++) - { - for(i = 0; i < M; i++) - { - for(j = k+1; j < N; j++) - { - B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda]; - } - } - } -return BLIS_SUCCESS; -} - -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, no-transpose, non-unit diagonal - * dimensions A: mxm X: mxn B: mxn - - b01---> - * ***************** - ** * * * * * - * * * * * * * - * * *b01* * * * - * * * * * * * -a10 ****** b11 ***************** - | * * * | * * * * * - | * * * | * * * * * - | *a10*a11* | *b11* * * * - v * * * v * * * * * - *********** ***************** - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - **************** ***************** - a11---> -*/ - -static err_t bli_dtrsm_small_AlXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - - dim_t D_MR = 4; //size of block along 'M' dimpension - dim_t D_NR = 8; //size of block along 'N' dimension - - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - if(max(m,n) > 90) - return BLIS_NOT_YET_IMPLEMENTED; - - dim_t m_remainder = m % D_MR; //number of remainder rows - dim_t n_remainder = n % D_NR; //number of remainder columns - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - double *ptr_b01_dup; - - double ones = 1.0; - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - - - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - - ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] - - //extract a00 - ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - - //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] - - //extract a11 - ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] - - ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] - ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A110[][0] 1/A11[2][2] 1/A11[2][2] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] - - //perform mul operation - ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /= A11[2][2] - ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] - - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11);//1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(ROw2): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] - - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] - - //perform mul operation - ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3] - ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] - } - - if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - - b01 += 1; //move to next row of B01 - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - - ymm0 = _mm256_broadcast_sd((double const *)&ones); - - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - - ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a00 - ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - - //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0] - ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0] - - //extract a11 - ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= B11[0-3][0]*A11[3][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= B11[0-3][4]*A11[3][4] - - ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1] - ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] - - a11 += cs_a; - - //extract a22 - ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] - - //perform mul operation - ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2] - ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2] - - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] - - a11 += cs_a; - - //extract a33 - ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(ROw2): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[0-3][3] -= A11[3][2]*B11[0-3][2] - - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[0-3][7] -= A11[3][2]*B11[0-3][6] - - //perform mul operation - ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3] - ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b * 7)); //load B11[0-3][7] - //determine correct values to store - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); - ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); - ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); - ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); - ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); - } - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store(B11[0-3][7]) - - } - } - - if((n & 4)) //implementation for remainder columns(when 'N' is a multiple of 4) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B01[0-3][3] *alpha -= ymm7 - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] -/* - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); -*/ - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] -/* - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); -*/ - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] -/* - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); -*/ - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] -/* - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); -*/ - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]* B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]* B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]* B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); - } - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - } - - n_remainder -= 4; - j += 4; - - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const*)&ones); - } - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - ///GEMM code ends/// - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - //compute reciprocals of L(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] -/* - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); -*/ - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] -/* - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); -*/ - //extract a00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0] - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0] - - //extract diag a11 from a - ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1] - - - //extract diag a22 from a - ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2] - - //extract diag a33 from a - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - - k_iter = i / D_MR; //number of times GEMM operations to be performed - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ///GEMM for previously calculated values /// - - - //load 4x4 block from b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[0][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[1][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[2][3] - - b01 += 1; //move to next row of B - - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[3][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] * alpha -= ymm6 - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] * alpha -= ymm7 - - ///implement TRSM/// - //determine correct values to store - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm11, ymm3, 0x30); - - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - } - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - - ///scalar code for trsm without alpha/// - dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b); - } - } - return BLIS_SUCCESS; -} - -/* TRSM for the case AX = alpha * B, Double precision - * A is lower-triangular, no-transpose, unit diagonal - * dimensions A: mxm X: mxn B: mxn - - b01---> - * ***************** - ** * * * * * - * * * * * * * - * * *b01* * * * - * * * * * * * -a10 ****** b11 ***************** - | * * * | * * * * * - | * * * | * * * * * - | *a10*a11* | *b11* * * * - v * * * v * * * * * - *********** ***************** - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - * * * * * * * * * - **************** ***************** - a11---> -*/ - -static err_t bli_dtrsm_small_AlXB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - - dim_t D_MR = 4; //size of block along 'M' dimpension - dim_t D_NR = 8; //size of block along 'N' dimension - - dim_t m = bli_obj_length(b); // number of rows of matrix B - dim_t n = bli_obj_width(b); // number of columns of matrix B - - if(max(m,n) > 90) - return BLIS_NOT_YET_IMPLEMENTED; - - dim_t m_remainder = m % D_MR; //number of remainder rows - dim_t n_remainder = n % D_NR; //number of remainder columns - - dim_t cs_a = bli_obj_col_stride(a); // column stride of A - dim_t cs_b = bli_obj_col_stride(b); // column stride of B - - dim_t i, j, k; //loop variables - dim_t k_iter; //number of times GEMM to be performed - - double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM - double *ptr_b01_dup; - - double ones = 1.0; - - //scratch registers - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - - - for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code begins/// - - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //mobe to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM - } - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///transpose of B11// - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] - - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2] - - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3] - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3] - } - - if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4) - - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm11 = _mm256_setzero_pd(); - ymm12 = _mm256_setzero_pd(); - ymm13 = _mm256_setzero_pd(); - ymm14 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); - - ///GEMM code Begins/// - for(k = 0; k< k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] ) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7] - - b01 += 1; //move to next row of B01 - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4] - ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5] - ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6] - ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7] - - b01 += 1; //move to next row of B - - ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3]) - ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3]) - ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3]) - ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - - ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0] - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1] - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2] - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3] - ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4] - ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5] - ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6] - ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7] - - ///implement TRSM/// - - ///unpacklow/// - ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5] - ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7] - - //rearrange low elements - ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3] - ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1] - ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3] - - //rearrange high elements - ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3] - ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3] - - //broadcast diagonal elements of A11 - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1] - ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0] - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0] - - a11 += cs_a; - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= B11[0-3][0]*A11[3][0] - - ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= B11[0-3][4]*A11[3][4] - - ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1] - ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1] - - ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5] - ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5] - - ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2] - - a11 += cs_a; - - //(ROw2): FMA operations - ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[0-3][3] -= A11[3][2]*B11[0-3][2] - - ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[0-3][7] -= A11[3][2]*B11[0-3][6] - - //unpacklow// - ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2] - ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4] - ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6] - - ///unpack high/// - ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7] - ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5] - ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7] - - ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3] - ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b * 7)); //load B11[0-3][7] - //determine correct values to store - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08); - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30); - ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30); - ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30); - ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30); - ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30); - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E); - ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E); - ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E); - ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E); - ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E); - } - - _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5]) - _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6]) - _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store(B11[0-3][7]) - - } - } - - if((n & 4)) //implementation for remainder columns(when 'N' is a multiple of 4) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4) - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B01[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B01[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B01[0-3][3] *alpha -= ymm7 - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3]) - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //looop for number of GEMM operations - { - ptr_b01_dup = b01; - - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]* B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]* B11[1][0-3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]* B11[2][0-3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //load 4x4 block from b11 - ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3] - - //determine correct values to store - - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08); - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30); - ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30); - ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30); - ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30); - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E); - ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E); - ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E); - ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E); - } - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) - - } - - n_remainder -= 4; - j += 4; - - } - - if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR) - { - for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4) - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value - - ///GEMM for previously calculated values /// - - //load 4x4 block from b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const*)&ones); - } - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3] - - b01 += 1; - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - } - - ///GEMM code ends/// - - ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4 - ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5 - ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6 - ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7 - - ///implement TRSM/// - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3] - - ////unpacklow//// - ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1] - ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3] - - //rearrange low elements - ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3] - ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3] - - ////unpackhigh//// - ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1] - ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3] - - //rearrange high elements - ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3] - ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3] - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3] - ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3] - ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3] - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3] - ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3] - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3] - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2] - ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2] - - //rearrange low elements - ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - - ////unpackhigh//// - ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3] - - ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3] - - //rearrange high elements - ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - - } - if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR) - { - a10 = L +i; //pointer to block of A to be used for GEMM - a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM - b01 = B + j*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM - - - k_iter = i / D_MR; //number of times GEMM operations to be performed - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value - - ///GEMM for previously calculated values /// - - - //load 4x4 block from b11 - if(n_remainder == 3) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 2) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 1) - { - ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_broadcast_sd((double const *)&ones); - ymm2 = _mm256_broadcast_sd((double const *)&ones); - ymm3 = _mm256_broadcast_sd((double const *)&ones); - } - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_b01_dup = b01; - ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3] - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[0][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[1][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[2][3] - - b01 += 1; //move to next row of B - - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2]) - - ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[3][3] - - b01 += 1; //move to next row of B - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3]) - - a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM - b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM - - } - - ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] * alpha -= ymm4 - ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] * alpha -= ymm5 - ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] * alpha -= ymm6 - ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] * alpha -= ymm7 - - ///implement TRSM/// - //determine correct values to store - if(m_remainder == 3) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08); - - } - if(m_remainder == 2) - { - ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30); - ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30); - ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30); - ymm3 = _mm256_permute2f128_pd(ymm11, ymm3, 0x30); - - } - if(m_remainder == 1) - { - ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E); - ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E); - ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E); - ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E); - } - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0]) - } - - ///scalar code for trsm without alpha/// - dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b); - } - } - return BLIS_SUCCESS; -} - - -/*implements TRSM for the case XA = alpha * B - *A is upper triangular, non-unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - */ - -/* b11---> a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ -static err_t bli_dtrsm_small_XAuB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if(max(m,n)>150 && (m/n) < 22) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j*cs_a; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - } - } - } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = 0; (j+D_NR-1) a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ - -static err_t bli_dtrsm_small_XAuB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if((max(m,n)>180) && (m/n)<22) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j*cs_a; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2] - - //4th col - a11 += cs_a; - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3] - - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - } - } - } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = 0; (j+D_NR-1) a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ -static err_t bli_dtrsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if(max(m,n) > 120) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1] - - ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1] - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2] - - ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3] - - ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - } - } - } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = 0; (j+D_NR-1) a01 ----> - ***************** *********** - *b01*b11* * * * * * * -b11 * * * * * **a01 * * a11 - | ***************** ********* | - | * * * * * *a11* * | - | * * * * * * * * | - v ***************** ****** v - * * * * * * * - * * * * * * * - ***************** * * - * - -*/ -static err_t bli_dtrsm_small_XAltB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if(max(m,n) > 120) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction - { - for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction - { - a01 = L + j; //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used in GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j; //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 2) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - if(n_remainder == 1) - { - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //(Row1): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1] - ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2] - ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3] - - ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1] - ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2] - ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3] - - //(Row2)FMA operations - ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2] - ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3] - - ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2] - ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3] - - //(Row3)FMA operations - ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3] - - ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - } - } - } - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = 0; (j+D_NR-1) 120) - return BLIS_NOT_YET_IMPLEMENTED; - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm11 = _mm256_mul_pd(ymm11, ymm0); - - ymm15 = _mm256_mul_pd(ymm15, ymm0); - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(row 3):FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - ymm10 = _mm256_mul_pd(ymm10, ymm0); - - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - ymm9 = _mm256_mul_pd(ymm9, ymm0); - - ymm13 = _mm256_mul_pd(ymm13, ymm0); - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - - - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - //compute reciprocals of L(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]) - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //extract a33 - ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm11 = _mm256_mul_pd(ymm11, ymm0); - - ymm15 = _mm256_mul_pd(ymm15, ymm0); - - //extract a22 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(row 3):FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - ymm10 = _mm256_mul_pd(ymm10, ymm0); - - ymm14 = _mm256_mul_pd(ymm14, ymm0); - - //extract a11 - ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(Row 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - ymm9 = _mm256_mul_pd(ymm9, ymm0); - - ymm13 = _mm256_mul_pd(ymm13, ymm0); - - //extract a00 - ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0] - - ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0] - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - } - } - if(i<0) - i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code end/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm3 = _mm256_mul_pd(ymm3, ymm15); - - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - ymm2 = _mm256_mul_pd(ymm2, ymm15); - - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - ymm1 = _mm256_mul_pd(ymm1, ymm15); - - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - ymm0 = _mm256_mul_pd(ymm0, ymm15); - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - - ///GEMM processing stars/// - - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - - } - - ///GEMM code ends/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm3 = _mm256_mul_pd(ymm3, ymm15); - - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - ymm2 = _mm256_mul_pd(ymm2, ymm15); - - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - ymm1 = _mm256_mul_pd(ymm1, ymm15); - - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - ymm0 = _mm256_mul_pd(ymm0, ymm15); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } - - } - m_remainder -= 4; - i -= 4; - } -// if(i < 0) i = 0; - if(m_remainder) ///implementation for remainder rows - { - dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); - } - return BLIS_SUCCESS; -} - -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, unit-diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* - -*/ -static err_t bli_dtrsm_small_XAlB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if(max(m,n) > 120) - return BLIS_NOT_YET_IMPLEMENTED; - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //(row 3):FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - //(Row 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - //(Row 1): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3]) - - - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4) - - ///load 4x4 block of b11 - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation begins/// - - for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - - //subtract the calculated GEMM block from current TRSM block - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0] - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0] - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1] - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1] - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2] - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2] - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3] - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3] - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - ymm7 = _mm256_broadcast_sd((double const *)(&ones)); - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - //2nd col - a11 += 1; - ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1] - - //3rd col - a11 += 1; - ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2] - ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2] - - //4th col - a11 += 1; - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3] - - - ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3] - ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3] - ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3] - - //(row 3):FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - //(Row 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - //(Row 1): FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - } - } - if(i<0) - i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - } - - ///GEMM code end/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //store(B11[x][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { - a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - - ///GEMM processing stars/// - - for(k = 0; k < k_iter; k++) - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3] - - a01 += 1; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM - - } - - ///GEMM code ends/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0] - ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0] - - //2nd col - a11 += cs_a; - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1] - - //3rd col - a11 += cs_a; - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2] - ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2] - - //4th col - a11 += cs_a; - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3] - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } - - } - m_remainder -= 4; - i -= 4; - } - if(m_remainder) - { - dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); - } - return BLIS_SUCCESS; -} - - -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, non-unit diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* - -*/ -static err_t bli_dtrsm_small_XAutB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if(max(m,n) > 120) - return BLIS_NOT_YET_IMPLEMENTED; - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - - ymm7 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - //extract a33 - ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm11 = _mm256_mul_pd(ymm11, ymm7); - - ymm15 = _mm256_mul_pd(ymm15, ymm7); - - //extract a22 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - //(Row 3): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - ymm10 = _mm256_mul_pd(ymm10, ymm7); - - ymm14 = _mm256_mul_pd(ymm14, ymm7); - - //extract a11 - ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - ymm9 = _mm256_mul_pd(ymm9, ymm7); - - ymm13 = _mm256_mul_pd(ymm13, ymm7); - - //extract A00 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - ymm8 = _mm256_mul_pd(ymm8, ymm7); - - ymm12 = _mm256_mul_pd(ymm12, ymm7); - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - - ymm7 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - //extract a33 - ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm11 = _mm256_mul_pd(ymm11, ymm7); - - ymm15 = _mm256_mul_pd(ymm15, ymm7); - - //extract a22 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - //(Row 3): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - ymm10 = _mm256_mul_pd(ymm10, ymm7); - - ymm14 = _mm256_mul_pd(ymm14, ymm7); - - //extract a11 - ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - ymm9 = _mm256_mul_pd(ymm9, ymm7); - - ymm13 = _mm256_mul_pd(ymm13, ymm7); - - //extract A00 - ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - ymm8 = _mm256_mul_pd(ymm8, ymm7); - - ymm12 = _mm256_mul_pd(ymm12, ymm7); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - } - } - if(i<0) - i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction - { - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM - } - - ///GEMM code end/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm3 = _mm256_mul_pd(ymm3, ymm15); - - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - ymm2 = _mm256_mul_pd(ymm2, ymm15); - - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - ymm1 = _mm256_mul_pd(ymm1, ymm15); - - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - ymm0 = _mm256_mul_pd(ymm0, ymm15); - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { - - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code end/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - - ymm14 = _mm256_broadcast_sd((double const *)&ones); - - //compute reciprocals of A(i,i) and broadcast in registers - ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1] - ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3] - - ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3] - ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3] - - //extract a33 - ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]) - - ymm3 = _mm256_mul_pd(ymm3, ymm15); - - //extract a22 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]) - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - ymm2 = _mm256_mul_pd(ymm2, ymm15); - - //extract a11 - ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]) - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - ymm1 = _mm256_mul_pd(ymm1, ymm15); - - //extract A00 - ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]) - ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]) - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - ymm0 = _mm256_mul_pd(ymm0, ymm15); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } - - } - m_remainder -= 4; - i -= 4; - } - if(m_remainder) ///implementation for remainder rows - { - dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b); - } - return BLIS_SUCCESS; -} - -/*implements TRSM for the case XA = alpha * B - *A is lower triangular, unit-diagonal, no transpose - *dimensions: X:mxn A:nxn B: mxn - */ - -/* <---b11 <---a11 - ***************** * - *b01*b11* * * * * - ^ * * * * * ^ * * - | ***************** | ******* - | * * * * * | * * * - | * * * * * a01* * * -b10 ***************** ************* - * * * * * * * * * - * * * * * * * * * - ***************** ******************* - -*/ -static err_t bli_dtrsm_small_XAutB_unitDiag( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - dim_t D_MR = 8; //block dimension along the rows - dim_t D_NR = 4; //block dimension along the columns - - dim_t m = bli_obj_length(b); //number of rows - dim_t n = bli_obj_width(b); //number of columns - dim_t m_remainder = m % D_MR; //number of corner rows - dim_t n_remainder = n % D_NR; //number of corner columns - dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A - dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B - - if(max(m,n) > 120) - return BLIS_NOT_YET_IMPLEMENTED; - - dim_t i, j, k; //loop variablse - dim_t k_iter; //determines the number of GEMM operations to be done - dim_t cs_b_offset[2]; //pre-calculated strides - - double ones = 1.0; - - double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha - double *L = a->buffer; //pointer to matrix A - double *B = b->buffer; //pointer to matrix B - - double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks - double *ptr_a01_dup; - - cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2; - cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3; - - //ymm scratch reginsters - __m256d ymm0, ymm1, ymm2, ymm3; - __m256d ymm4, ymm5, ymm6, ymm7; - __m256d ymm8, ymm9, ymm10, ymm11; - __m256d ymm12, ymm13, ymm14, ymm15; - __m256d ymm16; - - for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction - { - for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction - { - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0] - ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3] - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - //(Row 3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - //(Row 3): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - //(ROW 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - //(Row 1):FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - _mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR) - { - - a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM - b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4) - - ymm0 = _mm256_setzero_pd(); - ymm1 = _mm256_setzero_pd(); - ymm2 = _mm256_setzero_pd(); - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //broadcast 1st row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row - - //load 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0] - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1] - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3]) - - //broadcast 3rd row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A01 - - //load next 8x2 block of B10 - ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2]) - ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2]) - ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3]) - ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3]) - - ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3]) - - //broadcast 4th row of A01 - ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A01 - - ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0]) - ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1]) - ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2]) - ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3]) - - ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3]) - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code ends/// - - ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); - //load 8x4 block of B11 - if(n_remainder == 3) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0] - ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2] - } - if(n_remainder == 2) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0] - ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0] - ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1] - } - if(n_remainder == 1) - { - ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1] - ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1] - ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2] - ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2] - ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3] - ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3] - ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0] - ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0] - } - - ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 - ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1 - ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2 - ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3 - - ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4 - ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5 - ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6 - ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - //1st col - ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - //(Row 3): FMA operations - ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10); - ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8); - - //(Row 3): FMA operations - ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14); - ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12); - - //(ROW 2): FMA operations - ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9); - ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8); - - ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13); - ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12); - - //(Row 1):FMA operations - ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8); - - ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0]) - } - - } - } - if(i<0) - i += D_NR; - if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4) - { - for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction - { - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM - } - - ///GEMM code end/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - _mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3]) - - } - if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR) - { - - a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM - a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM - b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM - b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM - - k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4) - - ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha - ///GEMM for previous blocks /// - - ///load 4x4 block of b11 - if(n_remainder == 3) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - } - if(n_remainder == 2) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - } - if(n_remainder == 1) - { - ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2] - ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3] - ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0] - } - - //multiply by alpha - ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha - ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha - ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha - ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha - - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - - ///GEMM implementation starts/// - - for(k = 0; k < k_iter; k++) //loop for number of GEMM operations - { - ptr_a01_dup = a01; - - //load 4x4 bblock of b10 - ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0] - ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1] - ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2] - ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3] - - //broadcast 1st row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3]) - - //broadcast 2nd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3]) - - //braodcast 3rd row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3]) - - //broadcast 4th row of A01 - ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0] - ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1] - ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2] - ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3] - - a01 += cs_a; //move to next row of A - - ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0]) - ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1]) - ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2]) - ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3]) - - - b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM - a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM - } - - ///GEMM code end/// - - ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4 - ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5 - ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6 - ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7 - - ///implement TRSM/// - - ///read 4x4 block of A11/// - - - //1st col - ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0] - - a11 += cs_a; - - //2nd col - ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - - a11 += cs_a; - - //3rd col - ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - - a11 += cs_a; - - //4th col - ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1] - ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1] - ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1] - ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1] - - //(Row 3): FMA operations - ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2); - ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0); - - //(ROW 2): FMA operations - ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1); - ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0); - - //(Row 1):FMA operations - ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0); - - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1]) - _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2])) - _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0]) - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0]) - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1]) - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0]) - } - - } - m_remainder -= 4; - i -= 4; - } - if(m_remainder) - { - dtrsm_small_XAutB_unitDiag(a->buffer, b->buffer,AlphaVal, m_remainder, n, cs_a, cs_b); - } - return BLIS_SUCCESS; -} - - -/* - * AX = Alpha*B, Single precision, A: lower triangular - * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is mutiple of 8 - */ - -static err_t bli_strsm_small_AlXB ( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - obj_t alpha, beta; // gemm parameters - obj_t Ga, Gb, Gc; // for GEMM - int m = bli_obj_length(b); // number of rows of matrix B - int n = bli_obj_width(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int j; - int blk_size = 8; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if (m != BLI_AlXB_M_SP || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - /* Small _GEMM preparation code */ - bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha ); - bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta ); - - /* B = B - A*B */ - bli_setsc( -(1.0), 0.0, &alpha ); - bli_setsc( (1.0), 0.0, &beta ); - - - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb); - bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc); - - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb ); - bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc ); - - //first block of trsm - Gb.buffer = (void*)(B + i); - - //trsm of first 8xn block - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - fp_blis_strsm_microkernel = blis_strsm_microkernel; - } - else - { - blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; - } - bli_setsc( alphaVal, 0.0, &beta ); - } - else - { - if (isUnitDiag == 0) - { - blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - fp_blis_strsm_microkernel = blis_strsm_microkernel; - } - else - { - blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag; - } - } - - //gemm update - for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT - { - Ga.buffer = (void*)(L + j + i*lda); - Gc.buffer = (void*)(B + j); - - bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb - } - - //trsm of remaining blocks - for (i = blk_size; i < m; i += blk_size) - { - Gb.buffer = (void*)(B + i); - - fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - - for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT - { - Ga.buffer = (void*)(L + j + i*lda); - Gc.buffer = (void*)(B + j); - - bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb - } - - } // End of for loop - i - - return BLIS_SUCCESS; -} - - - -/* - * XA' = Alpha*B, Single precision, A: lower triangular - * This kernel implementation supports matrices A and B such that - * m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP - */ -static err_t bli_strsm_small_XAltB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int m = bli_obj_length(a); // number of rows of matrix B - int n = bli_obj_length(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if ((m&7) != 0 || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - else - { - trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - } - else - { - if (isUnitDiag == 0) - { - trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - else - { - trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - } - return BLIS_SUCCESS; -} - -/* - * A'X = Alpha*B, Single precision, A: upper triangular - * This kernel implementation supports matrices A and B such that - * m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP - */ -static err_t bli_strsm_small_AutXB( - side_t side, - obj_t* AlphaObj, - obj_t* a, - obj_t* b, - cntx_t* cntx, - cntl_t* cntl - ) -{ - int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken) - int n = bli_obj_width(b); // number of columns of matrix B - - int lda = bli_obj_col_stride(a); // column stride of A - int ldb = bli_obj_col_stride(b); // column stride of B - - int rsa = bli_obj_row_stride(a); // row stride of A - int rsb = bli_obj_row_stride(b); // row stride of B - - int i = 0; - int isUnitDiag = bli_obj_has_unit_diag(a); - - float alphaVal; - float *L = a->buffer; - float *B = b->buffer; - - if ((m&7) != 0 || (n&7) != 0) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM ) - { - return BLIS_NOT_YET_IMPLEMENTED; - } - - alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj)); - - if (alphaVal != 1) - { - if (isUnitDiag == 0) - { - trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - else - { - trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal); - } - } - else - { - if (isUnitDiag == 0) - { - trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - else - { - trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb); - } - } - return BLIS_SUCCESS; -} - -///////////////////////////// AX=B /////////////////////////////// -static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) -{ - float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags; - __m256 alphaReg; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols -} - -static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal) -{ - //float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags; - __m256 alphaReg; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - alphaReg = _mm256_broadcast_ss((float const *)&alphaVal); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //8th col - //ptr_l += cs_l; - //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - - //end loop of cols -} - -static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - //reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //8th col - //ptr_l += cs_l; - //mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - //mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - //mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - //reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - //mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - //mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - //mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - //mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - //mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - //mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - //mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - //mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - //mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - //mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - //mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - //mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - //mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - //mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - //mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - //mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - //mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - //mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - //mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - //mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols -} - -static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int j; - int cs_b_offset[6]; - //int row2, row4, row6; - float *ptr_b_dup; - - //70 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_cols[8]; - __m256 mat_a_cols_rearr[36]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - reciprocal_diags = _mm256_broadcast_ss((float const *)&ones); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - //_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0); - //row2 = (cs_l << 1); - //row4 = (cs_l << 2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - //_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0); - //row6 = row2 + row4; - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - //_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - //_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - //_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - //_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0); - - //reciprocal_diags = _mm256_loadu_ps((float const *)ones); - - //read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L - /*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l); - ptr_l += cs_l; - mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/ - - //Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers - //tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually. - //mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]); - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3)); - mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4)); - mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5)); - mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6)); - mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7)); - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //4rth col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //5th col - ptr_l += cs_l; - mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //6th col - ptr_l += cs_l; - mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - //7th col - ptr_l += cs_l; - mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - numCols_b -= 8; // blk_width = 8 - - //compute reciprocals of L(i,i) and broadcast in registers - mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]); - mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]); - - //mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55); - //mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55); - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC); - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20); - - //reciprocal of diagnol elements - reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]); - - //Start loop for cols of B to be processed in size of blk_width - for (j = 0; j < numCols_b; j += 8) - { - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Read next set of B columns - ptr_b += (cs_b + cs_b_offset[5]); - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5])); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols - } - - //Last block trsm processing - ptr_b_dup = ptr_b; - - /*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/ - - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44); - mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE); - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE); -#else - mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E); - mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E); - mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC); - mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33); - mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC); - mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20); - mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31); - mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20); - mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]); - //end loop of cols -} -static void blis_dtrsm_microkernel_alpha(double *ptr_l, - double *ptr_b, - int m, - int n, - int rs_l, - int rs_b, - int cs_l, - int cs_b, - double alphaVal - ) -{ - int j; - int n_remainder = n%4; - int cs_b_offset[2]; - double *ptr_b_dup; - double ones = 1.0; - __m256d mat_b_col[4]; - __m256d mat_b_rearr[4]; - __m256d mat_a_cols[4]; - __m256d mat_a_cols_rearr[10]; - __m256d mat_a_diag_inv[4]; - __m256d reciprocal_diags; - __m256d alphaReg; - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - - reciprocal_diags = _mm256_broadcast_sd((double const *)&ones); - alphaReg = _mm256_broadcast_sd((double const *)&alphaVal); - - //if(m % 4 == 0) - //{ - //1st col - mat_a_cols_rearr[0] = _mm256_broadcast_sd((double const *)(ptr_l+0)); - mat_a_cols_rearr[1] = _mm256_broadcast_sd((double const *)(ptr_l+1)); - mat_a_cols_rearr[3] = _mm256_broadcast_sd((double const *)(ptr_l+2)); - mat_a_cols_rearr[6] = _mm256_broadcast_sd((double const *)(ptr_l+3)); - - //2nd col - ptr_l += cs_l; - mat_a_cols_rearr[2] = _mm256_broadcast_sd((double const *)(ptr_l + 1)); - mat_a_cols_rearr[4] = _mm256_broadcast_sd((double const *)(ptr_l + 2)); - mat_a_cols_rearr[7] = _mm256_broadcast_sd((double const *)(ptr_l + 3)); - - //3rd col - ptr_l += cs_l; - mat_a_cols_rearr[5] = _mm256_broadcast_sd((double const *)(ptr_l + 2)); - mat_a_cols_rearr[8] = _mm256_broadcast_sd((double const *)(ptr_l + 3)); - - //4th col - ptr_l += cs_l; - mat_a_cols_rearr[9] = _mm256_broadcast_sd((double const *)(ptr_l + 3)); - //compute reciprocals of L(i,i) and broadcast in registers - mat_a_diag_inv[0] = _mm256_unpacklo_pd(mat_a_cols_rearr[0], mat_a_cols_rearr[2]); - mat_a_diag_inv[1] = _mm256_unpacklo_pd(mat_a_cols_rearr[5], mat_a_cols_rearr[9]); - - mat_a_diag_inv[0] = _mm256_blend_pd(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x0C); - reciprocal_diags = _mm256_div_pd(reciprocal_diags, mat_a_diag_inv[0]); - - for(j = 0;(j+3) < n; j += 4) - { - ptr_b_dup = ptr_b; - /*Shuffle to rearrange/transpose 8x4 block of B into contiguous row-wise registers*/ - - //read first set of 4x4 block of B into registers - mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); - mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b))); - //_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0); - mat_b_col[2] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[0])); - //_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0); - mat_b_col[3] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[1])); - - ////unpacklow//// - mat_b_rearr[1] = _mm256_unpacklo_pd(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[3] = _mm256_unpacklo_pd(mat_b_col[2], mat_b_col[3]); - - //rearrange low elements - mat_b_rearr[0] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x20); - mat_b_rearr[2] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x31); - - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_pd(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_pd(mat_b_col[2], mat_b_col[3]); - - //rearrange high elements - mat_b_rearr[1] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x20); - mat_b_rearr[3] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x31); - - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); - //extract a00 - mat_a_diag_inv[0] = _mm256_permute_pd(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_pd(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], mat_a_diag_inv[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_pd(reciprocal_diags, 0x03); - mat_a_diag_inv[1] = _mm256_permute2f128_pd(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - mat_b_rearr[1] = _mm256_fnmadd_pd(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], mat_a_diag_inv[1]); - - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_pd(reciprocal_diags, 0x00); - mat_a_diag_inv[2] = _mm256_permute2f128_pd(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x11); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_pd(reciprocal_diags, 0x0C); - mat_a_diag_inv[3] = _mm256_permute2f128_pd(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x11); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], mat_a_diag_inv[3]); - - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[1] = _mm256_unpacklo_pd(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[3] = _mm256_unpacklo_pd(mat_b_rearr[2], mat_b_rearr[3]); - - //rearrange low elements - mat_a_cols[0] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x20); - mat_a_cols[2] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_pd(mat_b_rearr[0], mat_b_rearr[1]); - - mat_b_rearr[1] = _mm256_unpackhi_pd(mat_b_rearr[2], mat_b_rearr[3]); - - //rearrange high elements - mat_a_cols[1] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x20); - mat_a_cols[3] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x31); - - //Read next set of B columns - ptr_b += (cs_b+cs_b_offset[1]); - _mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_pd((double *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_pd((double *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - _mm256_storeu_pd((double *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]); - - } - ptr_b_dup = ptr_b; - if(n_remainder == 3) - { - - //read first set of 4x4 block of B into registers - mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); - mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[0])); - mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 2) - { - //read first set of 4x4 block of B into registers - mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); - mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b))); - mat_b_col[2] = _mm256_broadcast_sd((double const *)&ones); - mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); - } - if(n_remainder == 1) - { - //read first set of 4x4 block of B into registers - mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b); - mat_b_col[1] = _mm256_broadcast_sd((double const *)&ones); - mat_b_col[2] = _mm256_broadcast_sd((double const *)&ones); - mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones); - } - /*Shuffle to rearrange/transpose 8x4 block of B into contiguous row-wise registers*/ - ////unpacklow//// - mat_b_rearr[1] = _mm256_unpacklo_pd(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[3] = _mm256_unpacklo_pd(mat_b_col[2], mat_b_col[3]); - //rearrange low elements - mat_b_rearr[0] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x20); - mat_b_rearr[2] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x31); - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg); - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg); - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_pd(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_pd(mat_b_col[2], mat_b_col[3]); - //rearrange high elements - mat_b_rearr[1] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x20); - mat_b_rearr[3] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x31); - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg); - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg); - //extract a00 - mat_a_diag_inv[0] = _mm256_permute_pd(reciprocal_diags, 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_pd(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_pd(reciprocal_diags, 0x03); - mat_a_diag_inv[1] = _mm256_permute2f128_pd(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0) - mat_b_rearr[1] = _mm256_fnmadd_pd(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_pd(reciprocal_diags, 0x00); - mat_a_diag_inv[2] = _mm256_permute2f128_pd(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x11); - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_pd(reciprocal_diags, 0x0C); - mat_a_diag_inv[3] = _mm256_permute2f128_pd(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x11); - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], mat_a_diag_inv[3]); - //--> Transpose and store results of columns of B block <--// - ////unpacklow//// - mat_a_cols[1] = _mm256_unpacklo_pd(mat_b_rearr[0], mat_b_rearr[1]); - mat_a_cols[3] = _mm256_unpacklo_pd(mat_b_rearr[2], mat_b_rearr[3]); - //rearrange low elements - mat_a_cols[0] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x20); - mat_a_cols[2] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x31); - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_pd(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_pd(mat_b_rearr[2], mat_b_rearr[3]); - //rearrange high elements - mat_a_cols[1] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x20); - mat_a_cols[3] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x31); - //Store the computed B columns - if(n_remainder == 3) - { - _mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_pd((double *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - _mm256_storeu_pd((double *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]); - } - if(n_remainder == 2) - { - _mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]); - _mm256_storeu_pd((double *)(ptr_b_dup + (cs_b)), mat_a_cols[1]); - } - if(n_remainder == 1) - { - _mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]); - } - - //} - -} - -#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0) - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0) - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - i = i1 + r; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - ptr_l_dup = ptr_l; - i4 = i2 + r; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - i4 = k >> 3; - ptr_l_dup += cs_l; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7)); - ptr_l_dup += cs_l; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - } - i2 += cs_b_offset[6]; - i += cs_l_offset[6]; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - { - i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2)); - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} -#else //rel 1.0 intrisic kernels (NOT OPT_CACHE_BLOCKING_L1) -static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - __m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - __m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); - mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); - mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); - mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); - mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); - mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); - mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); - mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += cs_l_offset[6]; - mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); - mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); - mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); - mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); - mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); - mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); - mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); - mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); - - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - //__m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - //(Row0) - mat_b_col[0] = mat_b_rearr[0][0]; - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): already done - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} - -static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[16][8]; - //__m256 mat_a_cols_rearr[8]; - __m256 mat_a_blk_elems[64]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg); - mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg); - mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg); - mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg); - mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg); - mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg); - mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg); - mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg); - - //(Row0) - mat_b_col[0] = mat_b_rearr[0][0]; - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b) - mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b) - mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b) - mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b) - mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b) - mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b) - mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]); - - //i += cs_b_offset[6]; - //ptr_b_dup += cs_b_offset[6]; - i += 8; - ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += 8; - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += cs_b_offset[6]; - i1 += cs_b_offset[6]; - i3 += cs_l_offset[6]; - - i = 0; - i2 = 0; - for (k = 0; k < numCols_b; k += 8) - { - i = i1 + k; - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg); - mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg); - mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg); - mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg); - mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg); - mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg); - mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg); - mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg); - - i2++; - } - - i = 0; - i2 = 0; - for (l = 0; l < j; l += 8) // move across m - { - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4)); - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7)); - - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1)); - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5)); - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7)); - - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i)); - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2)); - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3)); - mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4)); - mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5)); - mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6)); - mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7)); - - // _mm256_permute2f128_ps() - - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i)); - mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1)); - mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2)); - mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3)); - mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4)); - mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5)); - mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6)); - mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7)); - - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i)); - mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1)); - mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2)); - mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3)); - mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4)); - mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5)); - mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6)); - mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7)); - - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i)); - mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1)); - mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2)); - mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3)); - mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4)); - mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5)); - mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6)); - mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7)); - - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i)); - mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1)); - mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2)); - mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3)); - mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4)); - mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5)); - mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6)); - mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7)); - - i += cs_l_offset[6]; - - for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - { - /////////////////// Partial Lower 8x8 block trsm of B - - i4 = i2 + k; - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - i4 = k >> 3; - - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b) - - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b) - mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b) - mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b) - mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b) - mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b) - mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b) - mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b) - mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b) - - //end loop of cols - } - i2 += cs_b_offset[6]; - } - - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - i += cs_l; - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7)); - - k = 0; - for (i = 0; i < numCols_b; i+=8) - { - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A - - //(Row0): already done - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b) - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b) - - //////////////////////////////////////////////////////////////////////////////// - - //Store the computed B columns - - _mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - } - - - } - ///////////////////loop ends ///////////////////// -} -#endif //OPT_CACHE_BLOCKING_L1 - -//////////////////////////// AutX=B /////////////////////// -static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += 8; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ -#endif - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - float ones = 1.0; - int i, i1, i2, i3, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - __m256 mat_a_diag_inv[8]; - __m256 reciprocal_diags[2]; - __m256 alphaReg; - - reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones)); - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - //read diag elems of L 16x16 block - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - - reciprocal_diags[1] = reciprocal_diags[0]; - - //pack first 8 diags together - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]); - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]); - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i3 = 0; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - //Read next 8x8 block of A to get diag elements - i3 += 8; - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]); - - //pack 8 diags of A together - reciprocal_diags[0] = reciprocal_diags[1]; - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1 - mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5 - mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3 - mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7 - mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7 - - //reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7 - reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]); - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - //extract diag a00 from a - mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00); - //mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]); - - //extract diag a11 from a - mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00); - //mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]); - - //extract diag a22 from a - mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00); - //mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]); - - //extract diag a33 from a - mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00); - //mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]); - - //extract diag a44 from a - mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00); - mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11); - //mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]); - - //extract diag a55 from a - mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55); - mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11); - //mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]); - - //extract diag a66 from a - mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA); - mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11); - //mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]); - - //extract diag a77 from a - mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF); - mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11); - //mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]); - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - //Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - //Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - //Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]); - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b) -{ - //float ones = 1.0; - int i, i1, i2, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - - //(Row0) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} - -static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha) -{ - //float ones = 1.0; - int i, i1, i2, i4, j, k, l, r; - int cs_b_offset[7]; - int cs_l_offset[7]; - float *ptr_b_dup, *ptr_l_dup; - - //57 number of ymm(256 bits) registers used - __m256 mat_b_col[8]; - __m256 mat_b_rearr[8]; - __m256 mat_a_blk_elems[8]; - //__m256 mat_a_diag_inv[8]; - //__m256 reciprocal_diags[2]; - __m256 alphaReg; - alphaReg = _mm256_broadcast_ss((float const *)&alpha); - - // ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- // - - //L matrix offsets - cs_l_offset[0] = (cs_l << 1); - cs_l_offset[1] = cs_l + cs_l_offset[0]; - cs_l_offset[2] = (cs_l << 2); - cs_l_offset[3] = cs_l + cs_l_offset[2]; - cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2]; - cs_l_offset[5] = cs_l + cs_l_offset[4]; - cs_l_offset[6] = (cs_l_offset[5] + cs_l); - - cs_b_offset[0] = (cs_b << 1); - cs_b_offset[1] = cs_b + cs_b_offset[0]; - cs_b_offset[2] = (cs_b << 2); - cs_b_offset[3] = cs_b + cs_b_offset[2]; - cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2]; - cs_b_offset[5] = cs_b + cs_b_offset[4]; - cs_b_offset[6] = (cs_b_offset[5] + cs_b); - -#if 0 - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3)); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4)); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5)); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6)); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7)); - - //Broadcast A21 to A71 to registers - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2)); - mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3)); - mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4)); - mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5)); - mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6)); - mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7)); - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3)); - mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4)); - mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5)); - mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6)); - mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7)); - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4)); - mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5)); - mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6)); - mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7)); - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5)); - mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6)); - mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7)); - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6)); - mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7)); - - //Broadcast A76 to register - mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7)); -#endif - - - /***************** first set of 8 rows of B processing starts *****************/ - ptr_b_dup = ptr_b; - i = 0; - for (j = 0; j < numCols_b; j += 8) - { - /////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A - //read 8x8 block of B into registers - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); - - //(Row0) - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5])); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5])); - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5])); - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5])); - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5])); - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5])); - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b) - - - - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5])); - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]); - - i += cs_b_offset[6]; - ptr_b_dup += cs_b_offset[6]; - //i += 8; - //ptr_b_dup += 8; - } - - //c = 0; - /***************** first set of 8 cols of B processing done *****************/ - ptr_b_dup = ptr_b; - i1 = 0; - //Start loop for cols of B to be processed in size of blk_width - for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row - { - ptr_l += cs_l_offset[6]; - - - //ptr_b += j; - //ptr_b_dup += 8; - ptr_b_dup += 8; - i1 += 8; - i = i1; - i2 = 0; - - for (r = 0; r < numCols_b; r += GEMM_BLK_V1) - { -#if GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]); - mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]); - mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]); - mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - - ////unpackhigh//// - mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]); - mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]); - mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]); - mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - /* transpose steps end */ - - mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg); - mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg); - mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg); - mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg); - mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg); - mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg); - mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg); - mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg); -#endif - - //i = 0; - ptr_l_dup = ptr_l; - i4 = i2; - for (l = 0; l < j; l += 8) // move across m - { - //for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m) - //{ - /////////////////// Partial Lower 8x8 block trsm of B - //Read current 8 cols of B columns from specified 8x8 current-block of B - mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4); - mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b)); - mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0])); - mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1])); - mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2])); - mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3])); - mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4])); - mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5])); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]); - mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]); - mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]); - mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44); - mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE); - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE); -#else - mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E); - mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E); - mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC); - mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33); - mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC); - mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31); - /* transpose steps end */ - - //Broadcast A8,0 to A15,0 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i4 = k >> 3; - ptr_l_dup++; - -#if GEMM_ACCUM_A - //(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]); - mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]); - mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]); - mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]); - mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]); - mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]); - mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]); - mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]); -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,2 to A15,2 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,3 to A15,3 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,4 to A15,4 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,5 to A15,5 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,6 to A15,6 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A8,7 to A15,7 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - ptr_l_dup++; -#if GEMM_ACCUM_A - //(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b) - mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b) -#endif - //end loop of cols - //} - //i2 += cs_b_offset[6]; - i4 += 8; - } - //trsm solve - - k = 0; - //for (i2 = 0; i2 < numCols_b; i2 += 8) - //{ - //i2 = i1 + r; - /////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A -#if !GEMM_ACCUM_A - //Read 8 cols of B columns of Block-to-be-solved - mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i); - mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i)); - mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i)); - mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i)); - mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i)); - mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i)); - mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i)); - mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i)); - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg); - mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg); - mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg); - mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg); - mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg); - mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg); - mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg); - mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg); -#endif - //Broadcast A10 to A70 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l)); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4])); - mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5])); - //i += cs_l; - -#if GEMM_ACCUM_A - //(Row0): already done - -#else - mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]); -#endif - -#if GEMM_ACCUM_A - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#else - mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]); - mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]); - mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]); - mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]); - mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]); - mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]); - mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]); - - //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0) - mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b) -#endif - //Broadcast A21 to A71 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4])); - mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0) - mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A32 to A72 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4])); - mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0) - mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A43 to A73 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4])); - mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0) - mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A54 to A74 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4])); - mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0) - mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A65 to A75 to registers - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4])); - mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5])); - //i += cs_l; - - - - //(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0) - mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b) - - //Broadcast A76 to register - mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5])); - - - - //(Row7): FMA operations of b7 with elements of index (7, 0) - mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b) - - - - //////////////////////////////////////////////////////////////////////////////// - - /* transpose steps start */ - ////unpacklow//// - mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange low elements -#if REARRANGE_SHFL == 1 - mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44); - mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE); - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE); -#else - mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E); - mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E); - mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC); - mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33); - mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC); - mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33); -#endif - //Merge rearranged low elements into complete rows - mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20); - mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31); - mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20); - mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31); - - ////unpackhigh//// - mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]); - mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]); - mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]); - mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]); - - //Rearrange high elements -#if REARRANGE_SHFL == 1 - mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44); - mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE); - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE); -#else - mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E); - mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E); - mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC); - mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33); - mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC); - mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33); -#endif - - //Merge rearranged high elements into complete rows - mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20); - mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31); - mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20); - mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31); - /* transpose steps end */ - - //Store the computed B columns - _mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]); - _mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]); - _mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]); - //printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k)); - k++; - //} - i += cs_b_offset[6]; - i2 += cs_b_offset[6]; - } - } //numRows of A - ///////////////////loop ends ///////////////////// -} -#endif