Files
blis/kernels/zen/3/bli_trsm_small.c
Field G. Van Zee 09bd4f4f12 Add err_t* "return" parameter to malloc functions.
Details:
- Added an err_t* parameter to memory allocation functions including
  bli_malloc_intl(), bli_calloc_intl(), bli_malloc_user(),
  bli_fmalloc_align(), and bli_fmalloc_noalign(). Since these functions
  already use the return value to return the allocated memory address,
  they can't communicate errors to the caller through the return value.
  This commit does not employ any error checking within these functions
  or their callers, but this sets up BLIS for a more comprehensive
  commit that moves in that direction.
- Moved the typedefs for malloc_ft and free_ft from bli_malloc.h to
  bli_type_defs.h. This was done so that what remains of bli_malloc.h
  can be included after the definition of the err_t enum. (This ordering
  was needed because bli_malloc.h now contains function prototypes that
  use err_t.)
- Defined bli_is_success() and bli_is_failure() static functions in
  bli_param_macro_defs.h. These functions provide easy checks for error
  codes and will be used more heavily in future commits.
- Unfortunately, the additional err_t* argument discussed above breaks
  the API for bli_malloc_user(), which is an exported symbol in the
  shared library. However, it's quite possible that the only application
  that calls bli_malloc_user()--indeed, the reason it was marked for
  symbol exporting to begin with--is the BLIS testsuite. And if that's
  the case, this breakage won't affect anyone. Nonetheless, the "major"
  part of the so_version file has been updated accordingly to 4.0.0.
2021-03-31 17:09:36 -05:00

27822 lines
1.7 MiB

/*
BLIS
An object-based framework for developing high-performance BLAS-like
libraries.
Copyright (C) 2018-2019, Advanced Micro Devices, Inc.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
- Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
- Neither the name of The University of Texas at Austin nor the names
of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include "blis.h"
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
#include "immintrin.h"
#define GEMM_BLK_V1 8 //Block size used when performing the GEMM update and applying TRSM
#define GEMM_ACCUM_A 1 //Perform B1=B1-(B0*A0) directly instead of computing B1'=(B0*A0) and then B1=B1-B1'
#define OPT_CACHE_BLOCKING_L1 1 //Perform TRSM block-wise in blocks of GEMM_BLK_V1 columns instead of on all columns of B together
#define REARRANGE_SHFL 0 //Select whether operations are rearranged using blend or shuffle instructions
// Single-precision (SP) size limits; by name, M/N limits for the AlXB, XAltB
// and AutXB kernels. NOTE(review): their uses are outside this view -- confirm
// exact semantics (inclusive/exclusive bound, which dimension) at the use sites.
#define BLI_AlXB_M_SP 16
#define BLI_XAltB_N_SP 128
#define BLI_AutXB_M_SP 64
#define BLI_AutXB_N_SP 128
// Forward declarations of the double-precision object-level small-TRSM
// kernels selected by bli_trsm_small(). Each takes the same argument list as
// bli_trsm_small() (alpha is passed as an obj_t) and returns an err_t.
//
// XA = alpha*B; A is lower-triangular; no transpose; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAlB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA = alpha*B; A is lower-triangular; no transpose; double precision; unit diagonal
static err_t bli_dtrsm_small_XAlB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA' = alpha*B; A is lower-triangular and used transposed; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAltB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA' = alpha*B; A is lower-triangular and used transposed; double precision; unit diagonal
static err_t bli_dtrsm_small_XAltB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA = alpha*B; A is upper-triangular; no transpose; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAuB
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA = alpha*B; A is upper-triangular; no transpose; double precision; unit diagonal
static err_t bli_dtrsm_small_XAuB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA' = alpha*B; A is upper-triangular and used transposed; double precision; non-unit diagonal
static err_t bli_dtrsm_small_XAutB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA' = alpha*B; A is upper-triangular and used transposed; double precision; unit diagonal
static err_t bli_dtrsm_small_XAutB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// AX = alpha*B; A is lower-triangular; no transpose; double precision; non-unit diagonal
static err_t bli_dtrsm_small_AlXB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// AX = alpha*B; A is lower-triangular; no transpose; double precision; unit diagonal
static err_t bli_dtrsm_small_AlXB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// Forward declarations of the raw-buffer TRSM microkernels and block kernels.
// Common parameters (meanings inferred from names -- NOTE(review): confirm
// against the definitions, which are outside this view):
//   ptr_l       - the triangular matrix; ptr_b - the right-hand side B (overwritten)
//   numRows_lb  - rows of the triangular block / of B; numCols_b - columns of B
//   rs_* / cs_* - row / column strides of the L and B buffers
// The *_alpha variants take one extra trailing scalar (alphaVal / alpha); the
// *_unitDiag variants are, by name, for unit-diagonal triangular matrices.

// File-scope pointer to a single-precision microkernel; its signature matches
// blis_strsm_microkernel() and blis_strsm_microkernel_unitDiag() below.
static void (*fp_blis_strsm_microkernel)( float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b
);
// Single-precision microkernels (plain / alpha / unit-diagonal / alpha + unit-diagonal).
static void blis_strsm_microkernel( float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b
);
static void blis_strsm_microkernel_alpha( float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
float alphaVal
);
static void blis_strsm_microkernel_unitDiag( float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b
);
static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
float alphaVal
);
// Single-precision block kernels for the XA^T = B family (by name).
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b);
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
float alphaVal);
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b);
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
float alphaVal);
// Double-precision counterparts of the microkernels above.
static void blis_dtrsm_microkernel( double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b
);
static void blis_dtrsm_microkernel_alpha( double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
double alphaVal
);
static void blis_dtrsm_microkernel_unitDiag( double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b
);
static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
double alphaVal
);
// Double-precision block kernels for the XA^T = B family (by name).
static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b);
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
double alphaVal);
static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b);
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l,
double *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
double alphaVal);
// Single-precision block kernels for the A^T X = B family with upper-triangular
// A (by name). Note: the scalar parameter here is named `alpha`, not `alphaVal`.
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b);
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
float alpha);
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b);
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l,
float *ptr_b,
int numRows_lb,
int numCols_b,
int rs_l,
int rs_b,
int cs_l,
int cs_b,
float alpha);
// Further object-level small-TRSM kernels with the bli_trsm_small() argument
// list. Notation: A.' is the (non-conjugate) transpose of A.
//
// AX = alpha*B; A is lower-triangular; no transpose; single precision
static err_t bli_strsm_small_AlXB
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// A.'X = alpha*B; A is upper-triangular and used transposed; single precision
static err_t bli_strsm_small_AutXB
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// XA.' = alpha*B; A is lower-triangular and used transposed; single precision
static err_t bli_strsm_small_XAltB
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
// A.'X = alpha*B; A is upper-triangular and used transposed; double precision.
// NOTE(review): the visible dispatcher bli_trsm_small() does not call this
// kernel; that case currently returns BLIS_NOT_YET_IMPLEMENTED.
static err_t bli_dtrsm_small_AutXB
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
);
/*
 * bli_trsm_small
 *
 * Unpacked ("small matrix") TRSM entry point. Solves for X in one of
 *
 *     A  X = alpha B      or   A' X = alpha B      (side == BLIS_LEFT)
 *     X  A = alpha B      or   X  A' = alpha B     (any other side)
 *
 * where A is triangular and ' is the transpose recorded on the A object.
 * The solution X overwrites B.
 *
 * Handled here:
 *   - real single (BLIS_FLOAT) or double (BLIS_DOUBLE) precision, taken from
 *     the datatype of B;
 *   - A and B with unit row stride (column-major storage);
 *   - A marked as upper- or lower-triangular.
 * The (side, transpose, datatype, uplo, diag) combination selects one of the
 * kernels declared above, and that kernel's return code is passed through.
 *
 * Returns BLIS_SUCCESS when B is empty (nothing to solve). Otherwise returns
 * BLIS_NOT_YET_IMPLEMENTED (multithreaded build, alpha == 0, or a combination
 * without a kernel), BLIS_INVALID_ROW_STRIDE, BLIS_EXPECTED_REAL_DATATYPE,
 * BLIS_EXPECTED_TRIANGULAR_OBJECT, or the selected kernel's result.
 * NOTE(review): callers presumably fall back to the regular TRSM path on any
 * non-success code -- confirm at the call site.
 *
 * cntx and cntl are not inspected here; they are only forwarded to the kernel.
 */
err_t bli_trsm_small
(
side_t side,
obj_t* alpha,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
#ifdef BLIS_ENABLE_MULTITHREADING
	// This path is not implemented for multithreaded builds.
	return BLIS_NOT_YET_IMPLEMENTED;
#endif
	dim_t m = bli_obj_length( b );
	dim_t n = bli_obj_width( b );

	// Empty B: nothing to solve.
	if ( m == 0 || n == 0 )
		return BLIS_SUCCESS;

	// alpha == 0 makes X identically zero (B scaled by alpha); that trivial
	// case is not handled on this path.
	if ( bli_obj_equals( alpha, &BLIS_ZERO ) )
		return BLIS_NOT_YET_IMPLEMENTED;

	// The kernels index A and B as column-major, i.e. with unit row stride.
	if ( bli_obj_row_stride( a ) != 1 || bli_obj_row_stride( b ) != 1 )
		return BLIS_INVALID_ROW_STRIDE;

	// Use the accessor instead of hand-masking the info bit field, so this
	// code does not depend on where the datatype bits live.
	num_t dt = bli_obj_dt( b );

	// Only real single and double precision are supported.
	if ( dt != BLIS_DOUBLE && dt != BLIS_FLOAT )
		return BLIS_EXPECTED_REAL_DATATYPE;

	// A is expected to be triangular in TRSM.
	if ( !bli_obj_is_upper_or_lower( a ) )
		return BLIS_EXPECTED_TRIANGULAR_OBJECT;

	// Dispatch on (side, transa, dt, uplo, diag). A table of function pointers
	// indexed by these bits would also work; the cases are spelled out while
	// several of them are still unimplemented.
	if ( side == BLIS_LEFT )
	{
		if ( bli_obj_has_trans( a ) )
		{
			// A' X = alpha B
			if ( dt == BLIS_DOUBLE )
			{
				// No double-precision AutXB / AltXB kernel is wired up yet.
				return BLIS_NOT_YET_IMPLEMENTED;
			}
			if ( bli_obj_is_upper( a ) )
				return bli_strsm_small_AutXB( side, alpha, a, b, cntx, cntl );

			// Single-precision AltXB: not implemented.
			return BLIS_NOT_YET_IMPLEMENTED;
		}

		// A X = alpha B
		if ( bli_obj_is_upper( a ) )
		{
			// AuXB: not implemented in either precision.
			return BLIS_NOT_YET_IMPLEMENTED;
		}
		if ( dt == BLIS_DOUBLE )
		{
			if ( bli_obj_has_unit_diag( a ) )
				return bli_dtrsm_small_AlXB_unitDiag( side, alpha, a, b, cntx, cntl );
			return bli_dtrsm_small_AlXB( side, alpha, a, b, cntx, cntl );
		}
		// NOTE(review): this single-precision kernel receives both unit- and
		// non-unit-diagonal A; presumably it handles the diagonal itself.
		return bli_strsm_small_AlXB( side, alpha, a, b, cntx, cntl );
	}

	// Right side: X A = alpha B or X A' = alpha B.
	if ( dt == BLIS_FLOAT )
	{
		// Only X A' = alpha B with lower-triangular A has a single-precision
		// kernel (which, as above, receives both diagonal kinds).
		if ( bli_obj_has_trans( a ) && !bli_obj_is_upper( a ) )
			return bli_strsm_small_XAltB( side, alpha, a, b, cntx, cntl );
		return BLIS_NOT_YET_IMPLEMENTED;
	}

	// Double precision: every (transa, uplo, diag) combination has a kernel.
	if ( bli_obj_has_trans( a ) )
	{
		if ( bli_obj_is_upper( a ) )
		{
			if ( bli_obj_has_unit_diag( a ) )
				return bli_dtrsm_small_XAutB_unitDiag( side, alpha, a, b, cntx, cntl );
			return bli_dtrsm_small_XAutB( side, alpha, a, b, cntx, cntl );
		}
		if ( bli_obj_has_unit_diag( a ) )
			return bli_dtrsm_small_XAltB_unitDiag( side, alpha, a, b, cntx, cntl );
		return bli_dtrsm_small_XAltB( side, alpha, a, b, cntx, cntl );
	}

	if ( bli_obj_is_upper( a ) )
	{
		if ( bli_obj_has_unit_diag( a ) )
			return bli_dtrsm_small_XAuB_unitDiag( side, alpha, a, b, cntx, cntl );
		return bli_dtrsm_small_XAuB( side, alpha, a, b, cntx, cntl );
	}
	if ( bli_obj_has_unit_diag( a ) )
		return bli_dtrsm_small_XAlB_unitDiag( side, alpha, a, b, cntx, cntl );
	return bli_dtrsm_small_XAlB( side, alpha, a, b, cntx, cntl );
}
/*
 * Scalar reference kernel: forward substitution for A * X = B.
 *
 * A is an M x M lower-triangular matrix with a non-unit diagonal, used
 * without transposition; B is M x N. Both are column-major with leading
 * dimensions lda and ldb, and X overwrites B.
 *
 * For each k = 0..M-1 and each column of B: entry k is multiplied by
 * 1/A[k][k], then A[row][k] times that entry is subtracted from every
 * row > k.
 *
 * There is no alpha parameter: B is not scaled here.
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_AlXB (
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	for ( dim_t k = 0; k < M; ++k )
	{
		const double *a_col  = A + k * lda;     /* column k of A */
		const double  inv_kk = 1.0 / a_col[k];  /* 1 / A[k][k]   */

		for ( dim_t col = 0; col < N; ++col )
		{
			double *b_col = B + col * ldb;

			b_col[k] *= inv_kk;
			for ( dim_t row = k + 1; row < M; ++row )
				b_col[row] -= a_col[row] * b_col[k];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel: forward substitution for A * X = B where A is
 * M x M lower-triangular with a unit diagonal (diagonal entries are never
 * read), used without transposition. B is M x N. Both are column-major
 * (leading dimensions lda, ldb), and X overwrites B.
 *
 * There is no alpha parameter: B is not scaled here.
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_AlXB_unitDiag (
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	for ( dim_t k = 0; k < M; ++k )
	{
		const double *a_col = A + k * lda;      /* column k of A */

		for ( dim_t col = 0; col < N; ++col )
		{
			double *b_col = B + col * ldb;

			/* Eliminate the (already final) entry k from the rows below it. */
			for ( dim_t row = k + 1; row < M; ++row )
				b_col[row] -= a_col[row] * b_col[k];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A = B.
 *
 * A is N x N upper-triangular with a non-unit diagonal (no transpose); B is
 * M x N. Both are column-major with leading dimensions lda and ldb, and
 * X overwrites B.
 *
 * Columns are solved left to right. For column k and each row r: B[r][k] is
 * multiplied by 1/A[k][k], then B[r][k] * A[k][j] is subtracted from B[r][j]
 * for every j > k.
 *
 * There is no alpha parameter: B is not scaled here.
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAuB (
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	for ( dim_t k = 0; k < N; ++k )
	{
		double       *x_k    = B + k * ldb;           /* column k of B */
		const double  inv_kk = 1.0 / A[k + k * lda];  /* 1 / A[k][k]   */

		for ( dim_t r = 0; r < M; ++r )
		{
			x_k[r] *= inv_kk;
			for ( dim_t c = k + 1; c < N; ++c )
				B[r + c * ldb] -= x_k[r] * A[k + c * lda];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A = alpha * B.
 *
 * A is N x N lower-triangular with a non-unit diagonal (no transpose); B is
 * M x N. Both are column-major with leading dimensions lda and ldb, and
 * X overwrites B.
 *
 * B is first scaled by alpha. The columns are then solved right to left. For
 * column k and each row r (both visited in descending order): B[r][k] is
 * multiplied by 1/A[k][k], then B[r][k] * A[k][j] is subtracted from B[r][j]
 * for every j < k.
 *
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAlB (
double *A,
double *B,
double alpha,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	/* B := alpha * B */
	for ( dim_t c = 0; c < N; ++c )
	{
		double *b_c = B + c * ldb;
		for ( dim_t r = 0; r < M; ++r )
			b_c[r] *= alpha;
	}

	/* Backward sweep; `x-- != 0` visits x-1, x-2, ..., 0. */
	dim_t k = N;
	while ( k-- != 0 )
	{
		double       *x_k    = B + k * ldb;           /* column k of B */
		const double  inv_kk = 1.0 / A[k + k * lda];  /* 1 / A[k][k]   */

		dim_t r = M;
		while ( r-- != 0 )
		{
			x_k[r] *= inv_kk;

			dim_t c = k;
			while ( c-- != 0 )
				B[r + c * ldb] -= x_k[r] * A[k + c * lda];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A = alpha * B.
 *
 * A is N x N lower-triangular with a unit diagonal (diagonal entries are never
 * read), used without transposition; B is M x N. Both are column-major with
 * leading dimensions lda and ldb, and X overwrites B.
 *
 * B is first scaled by alpha. Then, for k = N-1 down to 0 and every j < k
 * (descending), B[:][k] * A[k][j] is subtracted from column j.
 *
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAlB_unitDiag(
double *A,
double *B,
double alpha,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	/* B := alpha * B */
	for ( dim_t c = 0; c < N; ++c )
	{
		double *b_c = B + c * ldb;
		for ( dim_t r = 0; r < M; ++r )
			b_c[r] *= alpha;
	}

	/* Backward sweep; `x-- != 0` visits x-1, x-2, ..., 0. */
	dim_t k = N;
	while ( k-- != 0 )
	{
		const double *x_k = B + k * ldb;       /* column k of B (final) */

		dim_t j = k;
		while ( j-- != 0 )
		{
			const double  a_kj = A[k + j * lda];  /* A[k][j] */
			double       *b_j  = B + j * ldb;     /* column j of B */

			dim_t r = M;
			while ( r-- != 0 )
				b_j[r] -= x_k[r] * a_kj;
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A' = alpha * B (A' is the transpose of A).
 *
 * A is N x N upper-triangular with a non-unit diagonal, so A' is lower-
 * triangular with A'[k][j] == A[j][k]. B is M x N. Both are column-major with
 * leading dimensions lda and ldb, and X overwrites B.
 *
 * B is first scaled by alpha. The columns are then solved right to left. For
 * column k and each row r (both descending): B[r][k] is multiplied by
 * 1/A[k][k], then B[r][k] * A[j][k] is subtracted from B[r][j] for every j < k.
 *
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAutB (
double *A,
double *B,
double alpha,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	/* B := alpha * B */
	for ( dim_t c = 0; c < N; ++c )
	{
		double *b_c = B + c * ldb;
		for ( dim_t r = 0; r < M; ++r )
			b_c[r] *= alpha;
	}

	/* Backward sweep; `x-- != 0` visits x-1, x-2, ..., 0. */
	dim_t k = N;
	while ( k-- != 0 )
	{
		double       *x_k    = B + k * ldb;           /* column k of B */
		const double *a_k    = A + k * lda;           /* column k of A */
		const double  inv_kk = 1.0 / a_k[k];          /* 1 / A[k][k]   */

		dim_t r = M;
		while ( r-- != 0 )
		{
			x_k[r] *= inv_kk;

			dim_t c = k;
			while ( c-- != 0 )
				B[r + c * ldb] -= x_k[r] * a_k[c];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A' = alpha * B (A' is the transpose of A).
 *
 * A is N x N upper-triangular with a unit diagonal (diagonal entries are never
 * read), so A'[k][j] == A[j][k]. B is M x N. Both are column-major with leading
 * dimensions lda and ldb, and X overwrites B.
 *
 * B is first scaled by alpha. Then, for k = N-1 down to 0 and every j < k
 * (descending), B[:][k] * A[j][k] is subtracted from column j.
 *
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAutB_unitDiag(
double *A,
double *B,
double alpha,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	/* B := alpha * B */
	for ( dim_t c = 0; c < N; ++c )
	{
		double *b_c = B + c * ldb;
		for ( dim_t r = 0; r < M; ++r )
			b_c[r] *= alpha;
	}

	/* Backward sweep; `x-- != 0` visits x-1, x-2, ..., 0. */
	dim_t k = N;
	while ( k-- != 0 )
	{
		const double *x_k = B + k * ldb;       /* column k of B (final) */
		const double *a_k = A + k * lda;       /* column k of A         */

		dim_t j = k;
		while ( j-- != 0 )
		{
			const double  a_jk = a_k[j];          /* A[j][k] */
			double       *b_j  = B + j * ldb;     /* column j of B */

			dim_t r = M;
			while ( r-- != 0 )
				b_j[r] -= x_k[r] * a_jk;
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A' = B (A' is the transpose of A).
 *
 * A is N x N lower-triangular with a non-unit diagonal, so A' is upper-
 * triangular with A'[k][j] == A[j][k]. B is M x N. Both are column-major with
 * leading dimensions lda and ldb, and X overwrites B.
 *
 * Columns are solved left to right. For column k and each row r: B[r][k] is
 * multiplied by 1/A[k][k], then B[r][k] * A[j][k] is subtracted from B[r][j]
 * for every j > k.
 *
 * There is no alpha parameter: B is not scaled here.
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAltB (
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	for ( dim_t k = 0; k < N; ++k )
	{
		double       *x_k    = B + k * ldb;   /* column k of B */
		const double *a_k    = A + k * lda;   /* column k of A */
		const double  inv_kk = 1.0 / a_k[k];  /* 1 / A[k][k]   */

		for ( dim_t r = 0; r < M; ++r )
		{
			x_k[r] *= inv_kk;
			for ( dim_t c = k + 1; c < N; ++c )
				B[r + c * ldb] -= x_k[r] * a_k[c];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A' = B (A' is the transpose of A).
 *
 * A is N x N lower-triangular with a unit diagonal (diagonal entries are never
 * read), so A'[k][j] == A[j][k]. B is M x N. Both are column-major with leading
 * dimensions lda and ldb, and X overwrites B.
 *
 * For k = 0..N-1, each row r and every j > k: B[r][k] * A[j][k] is subtracted
 * from B[r][j].
 *
 * There is no alpha parameter: B is not scaled here.
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAltB_unitDiag(
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	for ( dim_t k = 0; k < N; ++k )
	{
		const double *x_k = B + k * ldb;   /* column k of B (final) */
		const double *a_k = A + k * lda;   /* column k of A         */

		for ( dim_t r = 0; r < M; ++r )
		{
			for ( dim_t c = k + 1; c < N; ++c )
				B[r + c * ldb] -= x_k[r] * a_k[c];
		}
	}
	return BLIS_SUCCESS;
}
/*
 * Scalar reference kernel for X * A = B.
 *
 * A is N x N upper-triangular with a unit diagonal (diagonal entries are never
 * read), used without transposition; B is M x N. Both are column-major with
 * leading dimensions lda and ldb, and X overwrites B.
 *
 * For k = 0..N-1, each row r and every j > k: B[r][k] * A[k][j] is subtracted
 * from B[r][j].
 *
 * There is no alpha parameter: B is not scaled here.
 * Always returns BLIS_SUCCESS.
 */
static err_t dtrsm_small_XAuB_unitDiag (
double *A,
double *B,
dim_t M,
dim_t N,
dim_t lda,
dim_t ldb
)
{
	for ( dim_t k = 0; k < N; ++k )
	{
		const double *x_k   = B + k * ldb;  /* column k of B (final)           */
		const double *a_row = A + k;        /* row k of A: A[k][c] = a_row[c*lda] */

		for ( dim_t r = 0; r < M; ++r )
		{
			for ( dim_t c = k + 1; c < N; ++c )
				B[r + c * ldb] -= x_k[r] * a_row[c * lda];
		}
	}
	return BLIS_SUCCESS;
}
/* TRSM for the case AX = alpha * B, Double precision
* A is lower-triangular, no-transpose, non-unit diagonal
* dimensions A: mxm X: mxn B: mxn
b01--->
* *****************
** * * * * *
* * * * * * *
* * *b01* * * *
* * * * * * *
a10 ****** b11 *****************
| * * * | * * * * *
| * * * | * * * * *
| *a10*a11* | *b11* * * *
v * * * v * * * * *
*********** *****************
* * * * * * * * *
* * * * * * * * *
* * * * * * * * *
* * * * * * * * *
**************** *****************
a11--->
*/
static err_t bli_dtrsm_small_AlXB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 4; //size of block along 'M' dimpension
dim_t D_NR = 8; //size of block along 'N' dimension
dim_t m = bli_obj_length(b); // number of rows of matrix B
dim_t n = bli_obj_width(b); // number of columns of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
|| (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n<D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t m_remainder = m & 3; //number of remainder rows
dim_t n_remainder = n & 7; //number of remainder columns
dim_t cs_a = bli_obj_col_stride(a); // column stride of A
dim_t cs_b = bli_obj_col_stride(b); // column stride of B
dim_t i, j, k; //loop variables
dim_t k_iter; //number of times GEMM to be performed
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
double *L = a->buffer; //pointer to matrix A
double *B = b->buffer; //pointer to matrix B
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
double *ptr_b01_dup;
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
double* f_temp;
double ones = 1.0;
//scratch registers
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension
{
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
///GEMM code begins///
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM
}
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0]
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1]
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2]
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3]
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4]
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5]
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6]
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7]
///implement TRSM///
///transpose of B11//
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm0 = _mm256_broadcast_sd((double const *)&ones);
//broadcast diagonal elements of A11
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1]
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); //A11[3][3]
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
//extract a00
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
//extract a11
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
a11 += cs_a;
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0]
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0]
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4]
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4]
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
a11 += cs_a;
//extract a22
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A110[][0] 1/A11[2][2] 1/A11[2][2]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
//(ROw2): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
//perform mul operation
ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /= A11[2][2]
ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
a11 += cs_a;
//extract a33
ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11);//1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
//(ROw2): FMA operations
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2]
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6]
//perform mul operation
ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3]
ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3]
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3]
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3]
}
if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR)
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4)
int iter;
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
f_t[iter] = (b11 + cs_b * 7)[iter];
f_temp = f_t;
}
else
f_temp = (b11 + cs_b * 7);
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
///GEMM code Begins///
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
b01 += 1; //move to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] )
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
b01 += 1; //move to next row of B01
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
b01 += 1; //move to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
b01 += 1; //move to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0]
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1]
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2]
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3]
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4]
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5]
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6]
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7]
if(3 == m_remainder)
{
///implement TRSM///
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm0 = _mm256_broadcast_sd((double const *)&ones);
//broadcast diagonal elements of A11
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1]
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2]
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm3, ymm0); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a00
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
//extract a11
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
a11 += cs_a;
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0]
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4]
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
a11 += cs_a;
//extract a22
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
//(ROw2): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
//perform mul operation
ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2]
ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
ymm15 = _mm256_broadcast_sd((double const *)(&ones));
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08);
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08);
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08);
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08);
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08);
}
else if(2 == m_remainder)
{
///implement TRSM///
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm0 = _mm256_broadcast_sd((double const *)&ones);
//broadcast diagonal elements of A11
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1]
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
ymm5 = _mm256_blend_pd(ymm5, ymm0, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a00
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
//extract a11
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
a11 += cs_a;
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
ymm10 = _mm256_broadcast_sd((double const *)&ones);
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
//determine correct values to store
ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30);
ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30);
ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30);
ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30);
ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30);
ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30);
ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30);
ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30);
}
else if(1 == m_remainder)
{
///implement TRSM///
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm0 = _mm256_broadcast_sd((double const *)&ones);
//broadcast diagonal elements of A11
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm0 = _mm256_div_pd(ymm0, ymm1); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a00
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(&ones));
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm9); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm9, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm9, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm9, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm9, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm9); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm9, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm9, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E);
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E);
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E);
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E);
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E);
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E);
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E);
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E);
}
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4])
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5])
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6])
_mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7])
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * 7)[iter] = f_t[iter];
}
}
}
if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4)
{
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4)
///GEMM for previously calculated values ///
//load 4x4 block from b11
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7
///implement TRSM///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
//2nd col
a11 += cs_a;
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
//3rd col
a11 += cs_a;
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
//4th col
a11 += cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[2][2] A11[2][2]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[1][1] A11[1][1] A11[3][3] A11[3][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//extract a00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
//extract diag a11 from a
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3]
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3]
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3]
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
//extract diag a22 from a
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3]
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3]
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
//extract diag a33 from a
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3]
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3]
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3])
}
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
dim_t iter;
if((j+4) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * 3)[iter];
}
else
f_temp = (b11 + cs_b * 3);
///GEMM for previously calculated values ///
//load 4x4 block from b11
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for(k = 0; k < k_iter; k++) //looop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7
if(3 == m_remainder)
{
///implement TRSM///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
//2nd col
a11 += cs_a;
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
//3rd col
a11 += cs_a;
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
//4th col
a11 += cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//extract a00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
//extract diag a11 from a
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3]
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
//extract diag a22 from a
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3]
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
//load 4x4 block from b11
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
}
else if( 2 == m_remainder )
{
///implement TRSM///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
//2nd col
a11 += cs_a;
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
//compute reciprocals of L(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//extract a00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
//extract diag a11 from a
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[][] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
//load 4x4 block from b11
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
//determine correct values to store
ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30);
ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30);
ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30);
ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30);
}
else if(1 == m_remainder)
{
///implement TRSM///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//extract a00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
ymm8 = _mm256_broadcast_sd((double const *)(&ones));
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
//load 4x4 block from b11
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3])
if((j+4) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * 3)[iter] = f_temp[iter];
}
}
n_remainder -= 4;
j += 4;
}
if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)
{
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previously calculated values ///
//load 4x4 block from b11
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
for(k = 0; k < k_iter; k++)
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
}
else if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
for(k = 0; k < k_iter; k++)
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
}
else if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
for(k = 0; k < k_iter; k++)
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
}
///implement TRSM///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
//2nd col
a11 += cs_a;
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
//3rd col
a11 += cs_a;
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
//4th col
a11 += cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//extract a00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
//extract diag a11 from a
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3]
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3]
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3]
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
//extract diag a22 from a
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3]
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3]
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
//extract diag a33 from a
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3]
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3]
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
if(3 == n_remainder)
{
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
}
else if(2 == n_remainder)
{
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
}
else if(1 == n_remainder)
{
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
}
}
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM operations to be performed
dim_t iter;
if((j+n_remainder) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter];
}
else
f_temp = (b11 + cs_b * (n_remainder -1));
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previously calculated values ///
//load 4x4 block from b11
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6
///implement TRSM///
//determine correct values to store
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2])
}
if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
///implement TRSM///
//determine correct values to store
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1])
}
if(n_remainder == 1)
{
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
///implement TRSM///
//determine correct values to store
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
}
_mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0])
}
if((j+n_remainder) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
}
///scalar code for trsm without alpha///
dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
}
}
return BLIS_SUCCESS;
}
/* TRSM for the case AX = alpha * B, double precision.
 * A is lower-triangular, non-transposed, with a unit diagonal
 * (diagonal entries of A are treated as 1).
 * Dimensions: A is m x m; X and B are m x n.
b01--->
* *****************
** * * * * *
* * * * * * *
* * *b01* * * *
* * * * * * *
a10 ****** b11 *****************
| * * * | * * * * *
| * * * | * * * * *
| *a10*a11* | *b11* * * *
v * * * v * * * * *
*********** *****************
* * * * * * * * *
* * * * * * * * *
* * * * * * * * *
* * * * * * * * *
**************** *****************
a11--->
*/
static err_t bli_dtrsm_small_AlXB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 4; //size of block along 'M' dimpension
dim_t D_NR = 8; //size of block along 'N' dimension
dim_t m = bli_obj_length(b); // number of rows of matrix B
dim_t n = bli_obj_width(b); // number of columns of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
|| (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n<D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t m_remainder = m & (3); //number of remainder rows
dim_t n_remainder = n & (7); //number of remainder columns
dim_t cs_a = bli_obj_col_stride(a); // column stride of A
dim_t cs_b = bli_obj_col_stride(b); // column stride of B
dim_t i, j, k; //loop variables
dim_t k_iter; //number of times GEMM to be performed
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
double *L = a->buffer; //pointer to matrix A
double *B = b->buffer; //pointer to matrix B
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
double *ptr_b01_dup;
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
double* f_temp;
double ones = 1.0;
//scratch registers
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension
{
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
///GEMM code begins///
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
b01 += 1; //mobe to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM
}
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0]
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1]
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2]
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3]
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4]
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5]
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6]
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7]
///implement TRSM///
///transpose of B11//
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
a11 += cs_a;
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0]
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0]
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4]
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
a11 += cs_a;
//(ROw2): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
a11 += cs_a;
//(ROw1): FMA operations
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2]
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6]
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3]
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3]
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3]
}
if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR)
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4)
dim_t iter;
if((j+D_NR) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * 7)[iter];
}
else
f_temp = (b11 + cs_b * 7);
ymm8 = _mm256_setzero_pd();
ymm9 = _mm256_setzero_pd();
ymm10 = _mm256_setzero_pd();
ymm11 = _mm256_setzero_pd();
ymm12 = _mm256_setzero_pd();
ymm13 = _mm256_setzero_pd();
ymm14 = _mm256_setzero_pd();
ymm15 = _mm256_setzero_pd();
///GEMM code Begins///
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
b01 += 1; //move to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] )
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
b01 += 1; //move to next row of B01
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
b01 += 1; //move to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2])
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
b01 += 1; //move to next row of B
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0]
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1]
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2]
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3]
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4]
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5]
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6]
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7]
if(3 == m_remainder)
{
///implement TRSM///
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
a11 += cs_a;
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0]
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4]
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
a11 += cs_a;
//(ROw2): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
ymm15 = _mm256_broadcast_sd((double const *)(&ones));
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08);
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08);
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08);
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08);
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08);
}
else if(2 == m_remainder)
{
///implement TRSM///
///unpacklow///
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
//rearrange low elements
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
//rearrange high elements
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
a11 += cs_a;
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
ymm10 = _mm256_broadcast_sd((double const *)&ones);
//unpacklow//
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
///unpack high///
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
//determine correct values to store
ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30);
ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30);
ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30);
ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30);
ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30);
ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30);
ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30);
ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30);
}
else if(1 == m_remainder)
{
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E);
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E);
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E);
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E);
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E);
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E);
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E);
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E);
}
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4])
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5])
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6])
_mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7])
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * 7)[iter] = f_temp[iter];
}
}
}
if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4)
{
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4)
///GEMM for previously calculated values ///
//load 4x4 block from b11
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7
///implement TRSM///
//1st col
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
//2nd col
a11 += cs_a;
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
//3rd col
a11 += cs_a;
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3]
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3]
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3]
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3]
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3]
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3]
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3])
}
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
dim_t iter;
if((j+4) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * 3)[iter];
}
else
f_temp = (b11 + cs_b * 3);
///GEMM for previously calculated values ///
//load 4x4 block from b11
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
for(k = 0; k < k_iter; k++) //looop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7
if(3 == m_remainder)
{
///implement TRSM///
//1st col
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
//2nd col
a11 += cs_a;
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3]
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3]
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
//load 4x4 block from b11
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
}
else if(2 == m_remainder)
{
///implement TRSM///
//1st col
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
//load 4x4 block from b11
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
//determine correct values to store
ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30);
ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30);
ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30);
ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30);
}
else if(1 == m_remainder)
{
//load 4x4 block from b11
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
//determine correct values to store
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3])
if((j+4) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * 3)[iter] = f_temp[iter];
}
}
n_remainder -= 4;
j += 4;
}
if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)
{
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previously calculated values ///
//load 4x4 block from b11
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
for(k = 0; k < k_iter; k++)
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
}
else if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
for(k = 0; k < k_iter; k++)
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
}
else if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
for(k = 0; k < k_iter; k++)
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
b01 += 1;
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
}
///implement TRSM///
//1st col
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
//2nd col
a11 += cs_a;
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
//3rd col
a11 += cs_a;
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
////unpacklow////
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
//rearrange low elements
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
////unpackhigh////
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
//rearrange high elements
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3]
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3]
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3]
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3]
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3]
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3]
//--> Transpose and store results of columns of B block <--//
////unpacklow////
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
//rearrange low elements
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
////unpackhigh////
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
//rearrange high elements
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
if(3 == n_remainder)
{
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
}
else if(2 == n_remainder)
{
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
}
else if(1 == n_remainder)
{
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
}
}
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
{
a10 = L +i; //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
k_iter = i / D_MR; //number of times GEMM operations to be performed
dim_t iter;
if((j+n_remainder) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter];
}
else
f_temp = (b11 + cs_b * (n_remainder -1));
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previously calculated values ///
//load 4x4 block from b11
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6
///implement TRSM///
//determine correct values to store
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2])
}
else if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
///implement TRSM///
//determine correct values to store
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1])
}
else if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_b01_dup = b01;
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
b01 += 1; //move to next row of B
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
///implement TRSM///
//determine correct values to store
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
}
_mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0])
}
if((j+n_remainder) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
}
///scalar code for trsm without alpha///
dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
}
}
return BLIS_SUCCESS;
}
/* Implements TRSM for the case X * A = alpha * B, writing the solution X back into B.
 * A is upper triangular, non-unit diagonal, not transposed.
 * Dimensions: X: m x n, A: n x n, B: m x n.
 */
/* b11---> a01 ---->
***************** ***********
*b01*b11* * * * * * *
b11 * * * * * **a01 * * a11
| ***************** ********* |
| * * * * * *a11* * |
| * * * * * * * * |
v ***************** ****** v
* * * * * * *
* * * * * * *
***************** * *
*
*/
static err_t bli_dtrsm_small_XAuB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME)
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double ones = 1.0;
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double *L = a->buffer; //pointer to matrix A
double *B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
double* f_temp;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
{
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j*cs_a; //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used in GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
//4th col
a11 += cs_a;
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
//extract a33
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
//(Row3)FMA operations
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
///load 4x4 block of b11
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//subtract the calculated GEMM block from current TRSM block
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
///GEMM code ends///
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
//4th col
a11 += cs_a;
ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
}
else if(2 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
///GEMM code ends///
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1)
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
}
else if(1 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
///GEMM code ends///
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
}
}
}
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
//4th col
a11 += cs_a;
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM for previous blocks ///
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1]
//3rd col
a11 += cs_a;
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm9, ymm14); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM for previous blocks ///
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm15 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM for previous blocks ///
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm14 = _mm256_div_pd(ymm14, ymm4); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
}
}
m_remainder -= 4;
i += 4;
}
if(m_remainder) ///omplementation for remainder rows
{
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
dim_t iter;
if((j+D_NR) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * 3)[iter];
}
else
f_temp = (b11 + cs_b * 3);
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
//4th col
a11 += cs_a;
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
ymm4 = _mm256_loadu_pd((double const *)(b11));
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
ymm7 = _mm256_loadu_pd((double const *)f_temp);
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * 3)[iter] = f_temp[iter];
}
}
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
dim_t iter;
if((j+n_remainder) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
}
else
f_temp = (b11 + cs_b * (n_remainder-1));
///GEMM for previous blocks ///
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///load 4x4 block of b11
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)f_temp, ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)f_temp); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
}
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
}
if((j+n_remainder) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
}
//scalar code for TRSM
dtrsm_small_XAuB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
}
}
return BLIS_SUCCESS;
}
/* Implements TRSM for the case XA = alpha * B, solving for X.
 * A is upper triangular with a unit diagonal and is not transposed.
 * Dimensions: X is m x n, A is n x n, B is m x n.
 */
/* b11---> a01 ---->
***************** ***********
*b01*b11* * * * * * *
b11 * * * * * **a01 * * a11
| ***************** ********* |
| * * * * * *a11* * |
| * * * * * * * * |
v ***************** ****** v
* * * * * * *
* * * * * * *
***************** * *
*
*/
static err_t bli_dtrsm_small_XAuB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME)
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double ones = 1.0;
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double *L = a->buffer; //pointer to matrix A
double *B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
double* f_temp;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
{
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j*cs_a; //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used in GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//2nd col
a11 += cs_a;
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
//3rd col
a11 += cs_a;
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
//4th col
a11 += cs_a;
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
//(Row3)FMA operations
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
///load 4x4 block of b11
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//subtract the calculated GEMM block from current TRSM block
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
///GEMM code ends///
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//2nd col
a11 += cs_a;
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
//3rd col
a11 += cs_a;
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
}
else if(2 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//2nd col
a11 += cs_a;
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
}
else if(1 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
}
}
}
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
//3rd col
a11 += cs_a;
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
//4th col
a11 += cs_a;
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM for previous blocks ///
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
//3rd col
a11 += cs_a;
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM for previous blocks ///
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM for previous blocks ///
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
}
}
m_remainder -= 4;
i += 4;
}
if(m_remainder) ///omplementation for remainder rows
{
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
dim_t iter;
if((j+D_NR) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b_offset[1])[iter];
}
else
f_temp = (b11 + cs_b_offset[1]);
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += cs_a;
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
//3rd col
a11 += cs_a;
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
//4th col
a11 += cs_a;
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
ymm4 = _mm256_loadu_pd((double const *)(b11));
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b_offset[1])[iter] = f_temp[iter];
}
}
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
{
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
dim_t iter;
if((j+n_remainder) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
}
else
f_temp = (b11 + cs_b * (n_remainder-1));
///GEMM for previous blocks ///
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///load 4x4 block of b11
if(3 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)f_temp, ymm2); //(store(B11[x][2]))
}
if(2 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
}
if(1 == n_remainder)
{
ymm0 = _mm256_loadu_pd((double const *)f_temp); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
}
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
}
if((j+n_remainder) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
}
//scalar code for TRSM
dtrsm_small_XAuB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
}
}
return BLIS_SUCCESS;
}
/* Implements TRSM for the case XA = alpha * B, where
 * A is lower triangular, has a non-unit diagonal, and is used transposed.
 * Dimensions: X is m x n, A is n x n, B is m x n.
 */
/* b11---> a01 ---->
***************** ***********
*b01*b11* * * * * * *
b11 * * * * * **a01 * * a11
| ***************** ********* |
| * * * * * *a11* * |
| * * * * * * * * |
v ***************** ****** v
* * * * * * *
* * * * * * *
***************** * *
*
*/
static err_t bli_dtrsm_small_XAltB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N)
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N)
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|| (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double ones = 1.0;
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double *L = a->buffer; //pointer to matrix A
double *B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
double* f_temp;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
{
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j; //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used in GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
//3rd col
a11 += 1;
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
//4th col
a11 += 1;
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
//extract a33
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
//(Row3)FMA operations
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
///load 4x4 block of b11
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//subtract the calculated GEMM block from current TRSM block
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
a01 += cs_a; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
a01 += cs_a; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
//3rd col
a11 += 1;
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
//4th col
a11 += 1;
ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
}
else if(2 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
a01 += cs_a; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
a01 += cs_a; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11);
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b));
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR));
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
}
else if(1 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
a01 += cs_a; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
a01 += cs_a; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
///implement TRSM///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm8 = _mm256_mul_pd(ymm8, ymm7); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm7); //B11[4-7][0] /= A11[0][0]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
}
}
}
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//2nd row
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//3rd row
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//4th row
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
if(3 == n_remainder)
{
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
a11 += cs_a;//move to next column
//2nd row
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
a11 += cs_a;//move to next column
//3rd row
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
a11 += cs_a;//move to next column
//4th row
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
a11 += cs_a;//move to next column
//2nd row
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
a11 += cs_a;//move to next column
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm15 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
ymm14 = _mm256_broadcast_sd((double const *)&ones);
ymm14 = _mm256_div_pd(ymm14, ymm4); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
ymm0 = _mm256_mul_pd(ymm0, ymm14); //B11[x][0] /= A11[0][0]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
}
}
m_remainder -= 4;
i += 4;
}
if(m_remainder) ///omplementation for remainder rows
{
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
dim_t iter;
if((j+D_NR) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b_offset[1])[iter];
}
else
f_temp = (b11 + cs_b_offset[1]);
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//2nd row
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//3rd row
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//4th row
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
ymm4 = _mm256_loadu_pd((double const *)(b11));
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b_offset[1])[iter] = f_temp[iter];
}
}
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
dim_t iter;
if((j+n_remainder) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
}
else
f_temp = (b11 + cs_b * (n_remainder-1));
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previous blocks ///
if(3 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(f_temp), ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
}
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
}
if((j+n_remainder) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
}
//scalar code for TRSM
dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
}
}
return BLIS_SUCCESS;
}
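/* Implements TRSM for the case X * A' = alpha * B ("XAltB"), where
 * A is lower triangular with a unit diagonal and is used transposed.
 * Dimensions: X is m x n, A is n x n, B is m x n.
 * The solution X overwrites B in place.
 */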
/* b11---> a01 ---->
***************** ***********
*b01*b11* * * * * * *
b11 * * * * * **a01 * * a11
| ***************** ********* |
| * * * * * *a11* * |
| * * * * * * * * |
v ***************** ****** v
* * * * * * *
* * * * * * *
***************** * *
*
*/
static err_t bli_dtrsm_small_XAltB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N)
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N)
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|| (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double *L = a->buffer; //pointer to matrix A
double *B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
double* f_temp;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
{
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j; //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used in GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
//3rd col
a11 += 1;
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
//4th col
a11 += 1;
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
//(Row3)FMA operations
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
///load 4x4 block of b11
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//subtract the calculated GEMM block from current TRSM block
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
a01 += cs_a; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
a01 += cs_a; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
//3rd col
a11 += 1;
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
//4th col
a11 += 1;
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
//(Row2)FMA operations
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
}
else if(2 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
a01 += cs_a; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
a01 += cs_a; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11);
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b));
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR));
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
//(Row1): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
}
else if(1 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
a01 += cs_a; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
a01 += cs_a; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
}
}
}
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//2nd row
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//3rd row
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
if(3 == n_remainder)
{
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
a11 += cs_a;//move to next column
//2nd row
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
a11 += cs_a;//move to next column
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
a11 += cs_a;//move to next column
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
}
}
m_remainder -= 4;
i += 4;
}
if(m_remainder) ///omplementation for remainder rows
{
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
dim_t iter;
if((j+D_NR) == n)
{
f_temp = f_t;
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b_offset[1])[iter];
}
else
f_temp = (b11 + cs_b_offset[1]);
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st row
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//2nd row
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//3rd row
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
a11 += cs_a;//move to next column
//(Row1): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
//(Row2)FMA operations
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
//(Row3)FMA operations
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
ymm4 = _mm256_loadu_pd((double const *)(b11));
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
if((j+D_NR) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b_offset[1])[iter] = f_temp[iter];
}
}
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
{
a01 = L + j; //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
dim_t iter;
err_t r_val;
if((j+n_remainder) == n)
{
f_temp = bli_malloc_user(4 * sizeof(double), &r_val);
for(iter = 0; iter < m_remainder; iter++)
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
}
else
f_temp = (b11 + cs_b * (n_remainder-1));
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previous blocks ///
if(3 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(f_temp), ymm2); //(store(B11[x][2]))
}
else if(2 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
}
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
}
else if(1 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM implementation ends
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
///implement TRSM///
if(3 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
}
else if(2 == m_remainder)
{
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
}
else if(1 == m_remainder)
{
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
}
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
}
if((j+n_remainder) == n)
{
for(iter = 0; iter < m_remainder; iter++)
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
}
//scalar code for TRSM
dtrsm_small_XAltB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
}
}
return BLIS_SUCCESS;
}
/*Implements TRSM for the case XA = alpha * B, solving for X in place (X overwrites B)
*A is lower triangular, non-unit diagonal, no transpose
*dimensions: X:mxn A:nxn B:mxn
*/
/* <---b11 <---a11
***************** *
*b01*b11* * * * *
^ * * * * * ^ * *
| ***************** | *******
| * * * * * | * * *
| * * * * * a01* * *
b10 ***************** *************
* * * * * * * * *
* * * * * * * * *
***************** *******************
*/
static err_t bli_dtrsm_small_XAlB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME)
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double ones = 1.0;
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double* restrict L = a->buffer; //pointer to matrix A
double* restrict B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
{
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
{
a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
//3rd col
a11 += 1;
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
//4th col
a11 += 1;
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//extract a33
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(row 3):FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
//extract a00
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
//(Row 1): FMA operations
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
///load 4x4 block of b11
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//subtract the calculated GEMM block from current TRSM block
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0]
//2nd col
a11 += 1;
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
//3rd col
a11 += 1;
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
//4th col
a11 += 1;
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//extract a33
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(row 3):FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
//extract a11
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(Row 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
ymm9 = _mm256_mul_pd(ymm9, ymm0);
ymm13 = _mm256_mul_pd(ymm13, ymm0);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(2 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//3rd col
a11 += 2;
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
//4th col
a11 += 1;
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//extract a33
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
//extract a22
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(row 3):FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm10 = _mm256_mul_pd(ymm10, ymm0);
ymm14 = _mm256_mul_pd(ymm14, ymm0);
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(1 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
///implement TRSM///
///read 4x4 block of A11///
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
//4th col
a11 += 3;
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
//compute reciprocals of L(i,i) and broadcast in registers
ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm7);
ymm15 = _mm256_mul_pd(ymm15, ymm7);
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
}
}
if(i<0)
i += D_NR;
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
{
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
//4th col
a11 += cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm3 = _mm256_mul_pd(ymm3, ymm15);
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
ymm2 = _mm256_mul_pd(ymm2, ymm15);
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
ymm1 = _mm256_mul_pd(ymm1, ymm15);
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
//(Row 1):FMA operations
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
ymm0 = _mm256_mul_pd(ymm0, ymm15);
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previous blocks ///
if(3 == n_remainder)
{
///load 4x4 block of b11
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += cs_a;
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
//4th col
a11 += cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm3 = _mm256_mul_pd(ymm3, ymm15);
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
ymm2 = _mm256_mul_pd(ymm2, ymm15);
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
ymm1 = _mm256_mul_pd(ymm1, ymm15);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
}
else if(2 == n_remainder)
{
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//3rd col
a11 += 2 * cs_a;
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
//4th col
a11 += cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm3 = _mm256_mul_pd(ymm3, ymm15);
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm2 = _mm256_mul_pd(ymm2, ymm15);
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//4th col
a11 += 3 * cs_a;
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm3 = _mm256_mul_pd(ymm3, ymm14);
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
}
}
m_remainder -= 4;
i -= 4;
}
// if(i < 0) i = 0;
if(m_remainder) ///implementation for remainder rows
{
dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
}
return BLIS_SUCCESS;
}
/* Implements TRSM for the case XA = alpha * B, solving for X
 * (the solution X is written back into B in place).
 * A is lower triangular, unit-diagonal, no transpose.
 * Dimensions: X: m x n, A: n x n, B: m x n.
 */
/* <---b11 <---a11
***************** *
*b01*b11* * * * *
^ * * * * * ^ * *
| ***************** | *******
| * * * * * | * * *
| * * * * * a01* * *
b10 ***************** *************
* * * * * * * * *
* * * * * * * * *
***************** *******************
*/
static err_t bli_dtrsm_small_XAlB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME)
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double* restrict L = a->buffer; //pointer to matrix A
double* restrict B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
{
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
{
a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += 1;
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
//3rd col
a11 += 1;
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
//4th col
a11 += 1;
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//(row 3):FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
//(Row 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
//(Row 1): FMA operations
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
///load 4x4 block of b11
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//subtract the calculated GEMM block from current TRSM block
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
///implement TRSM///
///read 4x4 block of A11///
//3rd col
a11 += 2;
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
//4th col
a11 += 1;
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//(row 3):FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
//(Row 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(2 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
///implement TRSM///
///read 4x4 block of A11///
//4th col
a11 += 3;
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
//(row 3):FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(1 == n_remainder)
{
///GEMM implementation begins///
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
}
}
if(i<0)
i += D_NR;
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
{
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
//2nd col
a11 += cs_a;
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
//(Row 1):FMA operations
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM for previous blocks ///
if(3 == n_remainder)
{
///load 4x4 block of b11
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//2nd col
a11 += cs_a;
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
//3rd col
a11 += cs_a;
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
}
else if(2 == n_remainder)
{
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//3rd col
a11 += 2 * cs_a;
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM processing stars///
for(k = 0; k < k_iter; k++)
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
a01 += 1; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
}
}
m_remainder -= 4;
i -= 4;
}
// if(i < 0) i = 0;
if(m_remainder) ///implementation for remainder rows
{
dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
}
return BLIS_SUCCESS;
}
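/*Implements TRSM for the case X * A' = alpha * B
 *A is upper triangular, non-unit diagonal, transposed
 *(i.e. op(A) = A' is lower triangular, so the column blocks of X are
 * solved from the last block of columns backwards to the first)
 *dimensions: X:mxn A:nxn B: mxn
 */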
/* <---b11 <---a11
***************** *
*b01*b11* * * * *
^ * * * * * ^ * *
| ***************** | *******
| * * * * * | * * *
| * * * * * a01* * *
b10 ***************** *************
* * * * * * * * *
* * * * * * * * *
***************** *******************
*/
static err_t bli_dtrsm_small_XAutB(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME)
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double ones = 1.0;
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double* restrict L = a->buffer; //pointer to matrix A
double* restrict B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
{
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
{
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
a11 += cs_a;
//2nd col
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//3rd col
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
a11 += cs_a;
//4th col
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm7 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//extract a33
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm7);
ymm15 = _mm256_mul_pd(ymm15, ymm7);
//extract a22
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
//(Row 3): FMA operations
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
ymm10 = _mm256_mul_pd(ymm10, ymm7);
ymm14 = _mm256_mul_pd(ymm14, ymm7);
//extract a11
ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(ROW 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
ymm9 = _mm256_mul_pd(ymm9, ymm7);
ymm13 = _mm256_mul_pd(ymm13, ymm7);
//extract A00
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
//(Row 1):FMA operations
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
ymm8 = _mm256_mul_pd(ymm8, ymm7);
ymm12 = _mm256_mul_pd(ymm12, ymm7);
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0]
a11 += cs_a;
//2nd col
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//3rd col
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
a11 += cs_a;
//4th col
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm7 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//extract a33
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm7);
ymm15 = _mm256_mul_pd(ymm15, ymm7);
//extract a22
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
//(Row 3): FMA operations
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
ymm10 = _mm256_mul_pd(ymm10, ymm7);
ymm14 = _mm256_mul_pd(ymm14, ymm7);
//extract a11
ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(ROW 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
ymm9 = _mm256_mul_pd(ymm9, ymm7);
ymm13 = _mm256_mul_pd(ymm13, ymm7);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(2 == n_remainder)
{
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
a11 += 2 * cs_a;
//3rd col
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
a11 += cs_a;
//4th col
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm7 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//extract a33
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm11 = _mm256_mul_pd(ymm11, ymm7);
ymm15 = _mm256_mul_pd(ymm15, ymm7);
//extract a22
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
//(Row 3): FMA operations
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm10 = _mm256_mul_pd(ymm10, ymm7);
ymm14 = _mm256_mul_pd(ymm14, ymm7);
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(1 == n_remainder)
{
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += 3 * cs_a;
//4th col
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm7 = _mm256_broadcast_sd((double const *)&ones);
ymm0 = _mm256_div_pd(ymm7, ymm6); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
ymm11 = _mm256_mul_pd(ymm11, ymm0);
ymm15 = _mm256_mul_pd(ymm15, ymm0);
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
}
}
if(i<0)
i += D_NR;
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
{
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
a11 += cs_a;
//2nd col
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//3rd col
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
a11 += cs_a;
//4th col
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm3 = _mm256_mul_pd(ymm3, ymm15);
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
ymm2 = _mm256_mul_pd(ymm2, ymm15);
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
ymm1 = _mm256_mul_pd(ymm1, ymm15);
//extract A00
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
//(Row 1):FMA operations
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
ymm0 = _mm256_mul_pd(ymm0, ymm15);
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
///GEMM for previous blocks ///
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///load 4x4 block of b11
if(3 == n_remainder)
{
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
ymm4 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0]
a11 += cs_a;
//2nd col
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//3rd col
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
a11 += cs_a;
//4th col
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm3 = _mm256_mul_pd(ymm3, ymm15);
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
ymm2 = _mm256_mul_pd(ymm2, ymm15);
//extract a11
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
ymm1 = _mm256_mul_pd(ymm1, ymm15);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
}
else if(2 == n_remainder)
{
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
a11 += 2 * cs_a;
//3rd col
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
a11 += cs_a;
//4th col
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
//extract a33
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
ymm3 = _mm256_mul_pd(ymm3, ymm15);
//extract a22
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm2 = _mm256_mul_pd(ymm2, ymm15);
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += 3 * cs_a;
//4th col
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
ymm14 = _mm256_broadcast_sd((double const *)&ones);
//compute reciprocals of A(i,i) and broadcast in registers
ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
ymm3 = _mm256_mul_pd(ymm3, ymm14);
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
}
}
m_remainder -= 4;
i -= 4;
}
if(m_remainder) ///implementation for remainder rows
{
dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
}
return BLIS_SUCCESS;
}
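/*implements TRSM for the case X * A^T = alpha * B (right side)
*op(A) = A^T is lower triangular: only the strictly upper triangle of the
*stored (column-major) A is referenced, and the diagonal is assumed to be
*unit, so it is never read. Columns of X are solved from last to first.
*dimensions: X:mxn A:nxn B: mxn
*/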
/* <---b11 <---a11
***************** *
*b01*b11* * * * *
^ * * * * * ^ * *
| ***************** | *******
| * * * * * | * * *
| * * * * * a01* * *
b10 ***************** *************
* * * * * * * * *
* * * * * * * * *
***************** *******************
*/
static err_t bli_dtrsm_small_XAutB_unitDiag(
side_t side,
obj_t* AlphaObj,
obj_t* a,
obj_t* b,
cntx_t* cntx,
cntl_t* cntl
)
{
dim_t D_MR = 8; //block dimension along the rows
dim_t D_NR = 4; //block dimension along the columns
dim_t m = bli_obj_length(b); //number of rows
dim_t n = bli_obj_width(b); //number of columns
dim_t m_remainder = m & 7; //number of corner rows
dim_t n_remainder = n & 3; //number of corner columns
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME)
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
)
return BLIS_NOT_YET_IMPLEMENTED;
#else
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
{
return BLIS_NOT_YET_IMPLEMENTED;
}
#endif
dim_t i, j, k; //loop variablse
dim_t k_iter; //determines the number of GEMM operations to be done
dim_t cs_b_offset[2]; //pre-calculated strides
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
double* restrict L = a->buffer; //pointer to matrix A
double* restrict B = b->buffer; //pointer to matrix B
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
double *ptr_a01_dup;
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
//ymm scratch reginsters
__m256d ymm0, ymm1, ymm2, ymm3;
__m256d ymm4, ymm5, ymm6, ymm7;
__m256d ymm8, ymm9, ymm10, ymm11;
__m256d ymm12, ymm13, ymm14, ymm15;
__m256d ymm16;
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
{
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
{
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
//load 8x4 block of B11
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
//1st col
a11 += cs_a;
//2nd col
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
a11 += cs_a;
//3rd col
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//4th col
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//(Row 3): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
//(Row 3): FMA operations
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
//(ROW 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
//(Row 1):FMA operations
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
{
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
ymm0 = _mm256_setzero_pd();
ymm1 = _mm256_setzero_pd();
ymm2 = _mm256_setzero_pd();
ymm3 = _mm256_setzero_pd();
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
//load 8x4 block of B11
if(3 == n_remainder)
{
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += 2 * cs_a;
//3rd col
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//4th col
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//(Row 3): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
//(Row 3): FMA operations
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
//(ROW 2): FMA operations
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(2 == n_remainder)
{
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += 3 * cs_a;
//4th col
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//(Row 3): FMA operations
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
//(Row 3): FMA operations
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
else if(1 == n_remainder)
{
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//broadcast 1st row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row
//load 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
//broadcast 2nd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
//broadcast 3rd row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A01
//load next 8x2 block of B10
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
//broadcast 4th row of A01
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A01
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code ends///
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
}
}
}
if(i<0)
i += D_NR;
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
{
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
{
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
///load 4x4 block of b11
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += cs_a;
//2nd col
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
a11 += cs_a;
//3rd col
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//4th col
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
//(Row 1):FMA operations
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
}
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
{
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
///GEMM for previous blocks ///
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();
///load 4x4 block of b11
if(3 == n_remainder)
{
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += 2 * cs_a;
//3rd col
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
a11 += cs_a;
//4th col
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
//(ROW 2): FMA operations
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
}
else if(2 == n_remainder)
{
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
///implement TRSM///
///read 4x4 block of A11///
a11 += 3 * cs_a;
//4th col
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
//(Row 3): FMA operations
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
}
else if(1 == n_remainder)
{
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
///GEMM implementation starts///
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
{
ptr_a01_dup = a01;
//load 4x4 bblock of b10
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
//broadcast 1st row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
//broadcast 2nd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
//braodcast 3rd row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
//broadcast 4th row of A01
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
a01 += cs_a; //move to next row of A
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
}
///GEMM code end///
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
}
}
m_remainder -= 4;
i -= 4;
}
if(m_remainder) ///implementation for remainder rows
{
dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
}
return BLIS_SUCCESS;
}
/*
 * AX = Alpha*B, Single precision, A: lower triangular
 * This kernel implementation supports matrices A and B such that m is equal
 * to BLI_AlXB_M_SP and n is a multiple of 8.
 *
 * B is overwritten with X by blocked forward substitution over 8-row blocks:
 * for each block i the 8x8 diagonal block A(i,i) is solved in place by a
 * microkernel, then every block row j below it is updated with
 *     B(j) = beta*B(j) - A(j,i)*X(i)
 * through bli_gemm_small().
 *
 * When alpha != 1, the first diagonal solve uses the *_alpha microkernel
 * (which scales its rows of B by alpha) and the first round of GEMM updates
 * uses beta = alpha. In this way every row of B is scaled by alpha exactly
 * once; all later rounds use beta = 1.
 *
 * Returns BLIS_NOT_YET_IMPLEMENTED (before touching B) for unsupported
 * sizes, BLIS_SUCCESS otherwise.
 */
static err_t bli_strsm_small_AlXB (
    side_t side,
    obj_t* AlphaObj,
    obj_t* a,
    obj_t* b,
    cntx_t* cntx,
    cntl_t* cntl
    )
{
    obj_t alpha, beta; // gemm scalars: alpha = -1, beta = scale of the updated rows
    obj_t Ga, Gb, Gc;  // views of A(j,i), X(i) and B(j) for the GEMM updates
    int m = bli_obj_length(b); // number of rows of matrix B
    int n = bli_obj_width(b);  // number of columns of matrix B
    int lda = bli_obj_col_stride(a); // column stride of A
    int ldb = bli_obj_col_stride(b); // column stride of B
    int rsa = bli_obj_row_stride(a); // row stride of A
    int rsb = bli_obj_row_stride(b); // row stride of B
    int i;
    int j;
    int blk_size = 8;
    int isUnitDiag = bli_obj_has_unit_diag(a);
    float alphaVal;
    float* restrict L = a->buffer;
    float* restrict B = b->buffer;

    if (m != BLI_AlXB_M_SP || (n&7) != 0)
    {
        return BLIS_NOT_YET_IMPLEMENTED;
    }
    if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
    {
        return BLIS_NOT_YET_IMPLEMENTED;
    }

    alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));

    /* Small _GEMM preparation code. These objects own heap buffers and are
       released with bli_obj_free() before returning. */
    bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha );
    bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta );

    /* B = B - A*B */
    bli_setsc( -(1.0), 0.0, &alpha );
    bli_setsc( (1.0), 0.0, &beta );

    bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga);
    bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb);
    bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc);

    bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga );
    bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb );
    bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc );

    for (i = 0; i < m; i += blk_size)
    {
        Gb.buffer = (void*)(B + i);

        // Solve the 8x8 diagonal block in place. The kernels are called
        // directly rather than through the shared fp_blis_strsm_microkernel
        // pointer (declared outside this function): writing that global from
        // here races with concurrent callers that pick a different kernel.
        if (i == 0 && alphaVal != 1)
        {
            if (isUnitDiag == 0)
            {
                blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
            }
            else
            {
                blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
            }
            // The rows below this block have not been scaled yet; the first
            // GEMM round applies alpha to them through beta.
            bli_setsc( alphaVal, 0.0, &beta );
        }
        else
        {
            if (isUnitDiag == 0)
            {
                blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
            }
            else
            {
                blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
            }
        }

        // GEMM update of every block row below the one just solved.
        for (j = i + blk_size; j < m; j += blk_size)
        {
            Ga.buffer = (void*)(L + j + i*lda);
            Gc.buffer = (void*)(B + j);
            bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga*Gb
        }

        // Every remaining row has now been scaled by alpha exactly once, so
        // later update rounds must not rescale it.
        bli_setsc( (1.0), 0.0, &beta );
    }

    bli_obj_free( &alpha );
    bli_obj_free( &beta );
    return BLIS_SUCCESS;
}
/*
 * XA' = Alpha*B, single precision, A: lower triangular.
 *
 * Small-size path only. It accepts the problem when both m (the length of A)
 * and n (the length of B) are multiples of 8, n <= BLI_XAltB_N_SP, and
 * m*(m + n) <= BLIS_SMALL_MATRIX_THRES_TRSM. Any other shape returns
 * BLIS_NOT_YET_IMPLEMENTED without touching B.
 *
 * The solve is delegated to one of the trsm_XAtB_block_allSmallSizedMatrices*
 * kernels. The kernel is chosen by whether alpha equals 1 and whether A is
 * marked unit-diagonal.
 */
static err_t bli_strsm_small_XAltB(
    side_t side,
    obj_t* AlphaObj,
    obj_t* a,
    obj_t* b,
    cntx_t* cntx,
    cntl_t* cntl
    )
{
    const int m         = bli_obj_length(a);     // number of rows of A
    const int n         = bli_obj_length(b);     // number of rows of B
    const int cs_tri    = bli_obj_col_stride(a); // column stride of A
    const int cs_rhs    = bli_obj_col_stride(b); // column stride of B
    const int rs_tri    = bli_obj_row_stride(a); // row stride of A
    const int rs_rhs    = bli_obj_row_stride(b); // row stride of B
    const int unit_diag = bli_obj_has_unit_diag(a);
    float *tri = a->buffer;
    float *rhs = b->buffer;
    float alpha_s;

    // Only multiples of 8 that fit the small-size limits are handled here.
    const int multiple_of_8 = ((m & 7) == 0) && ((n & 7) == 0);
    if (!multiple_of_8 || n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM)
        return BLIS_NOT_YET_IMPLEMENTED;

    alpha_s = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));

    if (alpha_s == 1)
    {
        // No scaling of B required.
        if (unit_diag)
            trsm_XAtB_block_allSmallSizedMatrices_unitDiag(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs);
        else
            trsm_XAtB_block_allSmallSizedMatrices(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs);
    }
    else
    {
        // Any alpha other than exactly 1 (NaN included) goes to the
        // alpha-scaling kernels.
        if (unit_diag)
            trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs, alpha_s);
        else
            trsm_XAtB_block_allSmallSizedMatrices_alpha(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs, alpha_s);
    }

    return BLIS_SUCCESS;
}
/*
 * A'X = Alpha*B, single precision, A: upper triangular.
 *
 * Small-size path only. It accepts the problem when m (the width of A, since
 * A is used transposed) and n (the width of B) are both multiples of 8,
 * m <= BLI_AutXB_M_SP, n <= BLI_AutXB_N_SP, and
 * m*(m + n) <= BLIS_SMALL_MATRIX_THRES_TRSM. Any other shape returns
 * BLIS_NOT_YET_IMPLEMENTED without touching B.
 *
 * The solve is delegated to one of the trsm_AutXB_block_allSmallSizedMatrices*
 * kernels. The kernel is chosen by whether alpha equals 1 and whether A is
 * marked unit-diagonal.
 */
static err_t bli_strsm_small_AutXB(
    side_t side,
    obj_t* AlphaObj,
    obj_t* a,
    obj_t* b,
    cntx_t* cntx,
    cntl_t* cntl
    )
{
    const int m         = bli_obj_width(a);      // rows of A' (width of A)
    const int n         = bli_obj_width(b);      // number of columns of B
    const int cs_tri    = bli_obj_col_stride(a); // column stride of A
    const int cs_rhs    = bli_obj_col_stride(b); // column stride of B
    const int rs_tri    = bli_obj_row_stride(a); // row stride of A
    const int rs_rhs    = bli_obj_row_stride(b); // row stride of B
    const int unit_diag = bli_obj_has_unit_diag(a);
    float *tri = a->buffer;
    float *rhs = b->buffer;
    float alpha_s;

    // Only multiples of 8 that fit the small-size limits are handled here.
    const int multiple_of_8 = ((m & 7) == 0) && ((n & 7) == 0);
    if (!multiple_of_8 || m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP
        || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM)
        return BLIS_NOT_YET_IMPLEMENTED;

    alpha_s = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));

    if (alpha_s == 1)
    {
        // No scaling of B required.
        if (unit_diag)
            trsm_AutXB_block_allSmallSizedMatrices_unitDiag(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs);
        else
            trsm_AutXB_block_allSmallSizedMatrices(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs);
    }
    else
    {
        // Any alpha other than exactly 1 (NaN included) goes to the
        // alpha-scaling kernels.
        if (unit_diag)
            trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs, alpha_s);
        else
            trsm_AutXB_block_allSmallSizedMatrices_alpha(tri, rhs, m, n, rs_tri, rs_rhs, cs_tri, cs_rhs, alpha_s);
    }

    return BLIS_SUCCESS;
}
///////////////////////////// AX=B ///////////////////////////////
static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
{
float ones = 1.0;
int j;
int cs_b_offset[6];
//int row2, row4, row6;
float *ptr_b_dup;
//70 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_cols[8];
__m256 mat_a_cols_rearr[36];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags;
__m256 alphaReg;
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
//row2 = (cs_l << 1);
//row4 = (cs_l << 2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
//row6 = row2 + row4;
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
//1st col
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
//2nd col
ptr_l += cs_l;
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//3rd col
ptr_l += cs_l;
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//4rth col
ptr_l += cs_l;
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//5th col
ptr_l += cs_l;
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//6th col
ptr_l += cs_l;
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//7th col
ptr_l += cs_l;
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//7th col
ptr_l += cs_l;
mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
numCols_b -= 8; // blk_width = 8
//compute reciprocals of L(i,i) and broadcast in registers
mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
//reciprocal of diagnol elements
reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
//Start loop for cols of B to be processed in size of blk_width
for (j = 0; j < numCols_b; j += 8)
{
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Read next set of B columns
ptr_b += (cs_b + cs_b_offset[5]);
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
//Last block trsm processing
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
{
//float ones = 1.0;
int j;
int cs_b_offset[6];
//int row2, row4, row6;
float *ptr_b_dup;
//70 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_cols[8];
__m256 mat_a_cols_rearr[36];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags;
__m256 alphaReg;
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
//reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
//row2 = (cs_l << 1);
//row4 = (cs_l << 2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
//row6 = row2 + row4;
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
//1st col
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
//2nd col
ptr_l += cs_l;
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//3rd col
ptr_l += cs_l;
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//4rth col
ptr_l += cs_l;
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//5th col
ptr_l += cs_l;
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//6th col
ptr_l += cs_l;
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//7th col
ptr_l += cs_l;
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//8th col
//ptr_l += cs_l;
//mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
numCols_b -= 8; // blk_width = 8
//compute reciprocals of L(i,i) and broadcast in registers
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
//mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
//mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
//reciprocal of diagnol elements
//reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
//Start loop for cols of B to be processed in size of blk_width
for (j = 0; j < numCols_b; j += 8)
{
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
//extract diag a11 from a
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Read next set of B columns
ptr_b += (cs_b + cs_b_offset[5]);
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
//Last block trsm processing
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
//extract diag a11 from a
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
//float ones = 1.0;
int j;
int cs_b_offset[6];
//int row2, row4, row6;
float *ptr_b_dup;
//70 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_cols[8];
__m256 mat_a_cols_rearr[36];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags;
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
//reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
//row2 = (cs_l << 1);
//row4 = (cs_l << 2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
//row6 = row2 + row4;
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
//1st col
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
//2nd col
ptr_l += cs_l;
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//3rd col
ptr_l += cs_l;
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//4rth col
ptr_l += cs_l;
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//5th col
ptr_l += cs_l;
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//6th col
ptr_l += cs_l;
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//7th col
ptr_l += cs_l;
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//8th col
//ptr_l += cs_l;
//mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
numCols_b -= 8; // blk_width = 8
//compute reciprocals of L(i,i) and broadcast in registers
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
//mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
//mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
//reciprocal of diagnol elements
//reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
//Start loop for cols of B to be processed in size of blk_width
for (j = 0; j < numCols_b; j += 8)
{
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
//extract diag a11 from a
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Read next set of B columns
ptr_b += (cs_b + cs_b_offset[5]);
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
//Last block trsm processing
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
//extract diag a11 from a
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
float ones = 1.0;
int j;
int cs_b_offset[6];
//int row2, row4, row6;
float *ptr_b_dup;
//70 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_cols[8];
__m256 mat_a_cols_rearr[36];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags;
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
//row2 = (cs_l << 1);
//row4 = (cs_l << 2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
//row6 = row2 + row4;
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
ptr_l += cs_l;
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
//1st col
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
//2nd col
ptr_l += cs_l;
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//3rd col
ptr_l += cs_l;
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//4rth col
ptr_l += cs_l;
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//5th col
ptr_l += cs_l;
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//6th col
ptr_l += cs_l;
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//7th col
ptr_l += cs_l;
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//7th col
ptr_l += cs_l;
mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
numCols_b -= 8; // blk_width = 8
//compute reciprocals of L(i,i) and broadcast in registers
mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
//reciprocal of diagnol elements
reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
//Start loop for cols of B to be processed in size of blk_width
for (j = 0; j < numCols_b; j += 8)
{
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Read next set of B columns
ptr_b += (cs_b + cs_b_offset[5]);
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
//Last block trsm processing
ptr_b_dup = ptr_b;
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
//--> Transpose and store results of columns of B block <--//
////unpacklow////
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
#else
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
//end loop of cols
}
#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags[2];
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
//read diag elems of L 16x16 block
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
reciprocal_diags[1] = reciprocal_diags[0];
//pack first 8 diags together
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
//Read next 8x8 block of A to get diag elements
i3 += cs_l_offset[6];
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
//pack 8 diags of A together
reciprocal_diags[0] = reciprocal_diags[1];
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
i = i1 + r;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
#endif
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
ptr_l_dup = ptr_l;
i4 = i2 + r;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
i4 = k >> 3;
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
}
i2 += cs_b_offset[6];
i += cs_l_offset[6];
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
{
i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
#if GEMM_ACCUM_A
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
} //numRows of A
///////////////////loop ends /////////////////////
}
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
{
float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags[2];
__m256 alphaReg;
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
//read diag elems of L 16x16 block
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
reciprocal_diags[1] = reciprocal_diags[0];
//pack first 8 diags together
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
#if 0
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
//Read next 8x8 block of A to get diag elements
i3 += cs_l_offset[6];
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
//pack 8 diags of A together
reciprocal_diags[0] = reciprocal_diags[1];
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
i = i1 + r;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
#endif
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
ptr_l_dup = ptr_l;
i4 = i2 + r;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
i4 = k >> 3;
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
}
i2 += cs_b_offset[6];
i += cs_l_offset[6];
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
{
i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
#if GEMM_ACCUM_A
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
} //numRows of A
///////////////////loop ends /////////////////////
}
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
//float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags[2];
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
//(Row0)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
i3 += cs_l_offset[6];
i = 0;
i2 = 0;
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
i = i1 + r;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
#endif
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
ptr_l_dup = ptr_l;
i4 = i2 + r;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
i4 = k >> 3;
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
}
i2 += cs_b_offset[6];
i += cs_l_offset[6];
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
{
i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
#if GEMM_ACCUM_A
//(Row0): already done
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
} //numRows of A
///////////////////loop ends /////////////////////
}
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
{
//float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags[2];
__m256 alphaReg;
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
#if 0
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
#endif
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
//(Row0)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
i3 += cs_l_offset[6];
i = 0;
i2 = 0;
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
i = i1 + r;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
#endif
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
ptr_l_dup = ptr_l;
i4 = i2 + r;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
i4 = k >> 3;
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
ptr_l_dup += cs_l;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
}
i2 += cs_b_offset[6];
i += cs_l_offset[6];
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
{
i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
#if GEMM_ACCUM_A
//(Row0): already done
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
} //numRows of A
///////////////////loop ends /////////////////////
}
#else //rel 1.0 intrinsic kernels (NOT OPT_CACHE_BLOCKING_L1)
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[16][8];
__m256 mat_a_cols_rearr[8];
__m256 mat_a_blk_elems[64];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags[2];
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
//read diag elems of L 16x16 block
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
reciprocal_diags[1] = reciprocal_diags[0];
//pack first 8 diags together
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
//Read next 8x8 block of A to get diag elements
i3 += cs_l_offset[6];
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
//pack 8 diags of A together
reciprocal_diags[0] = reciprocal_diags[1];
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
i = 0;
i2 = 0;
for (k = 0; k < numCols_b; k += 8)
{
i = i1 + k;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
i2++;
}
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
// _mm256_permute2f128_ps()
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
i += cs_l_offset[6];
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
i4 = i2 + k;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
i4 = k >> 3;
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
//end loop of cols
}
i2 += cs_b_offset[6];
}
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
k = 0;
for (i = 0; i < numCols_b; i+=8)
{
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
///////////////////loop ends /////////////////////
}
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
{
float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[16][8];
__m256 mat_a_cols_rearr[8];
__m256 mat_a_blk_elems[64];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags[2];
__m256 alphaReg;
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
//read diag elems of L 16x16 block
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
reciprocal_diags[1] = reciprocal_diags[0];
//pack first 8 diags together
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
//Read next 8x8 block of A to get diag elements
i3 += cs_l_offset[6];
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
//pack 8 diags of A together
reciprocal_diags[0] = reciprocal_diags[1];
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
i = 0;
i2 = 0;
for (k = 0; k < numCols_b; k += 8)
{
i = i1 + k;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
i2++;
}
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
// _mm256_permute2f128_ps()
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
i += cs_l_offset[6];
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
i4 = i2 + k;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
i4 = k >> 3;
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
//end loop of cols
}
i2 += cs_b_offset[6];
}
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
k = 0;
for (i = 0; i < numCols_b; i+=8)
{
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
k++;
}
}
///////////////////loop ends /////////////////////
}
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
//float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[16][8];
//__m256 mat_a_cols_rearr[8];
__m256 mat_a_blk_elems[64];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags[2];
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
//(Row0)
mat_b_col[0] = mat_b_rearr[0][0];
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
i3 += cs_l_offset[6];
i = 0;
i2 = 0;
for (k = 0; k < numCols_b; k += 8)
{
i = i1 + k;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
i2++;
}
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
// _mm256_permute2f128_ps()
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
i += cs_l_offset[6];
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
i4 = i2 + k;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
i4 = k >> 3;
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
//end loop of cols
}
i2 += cs_b_offset[6];
}
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
k = 0;
for (i = 0; i < numCols_b; i+=8)
{
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
//(Row0): already done
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
///////////////////loop ends /////////////////////
}
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
{
//float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[16][8];
//__m256 mat_a_cols_rearr[8];
__m256 mat_a_blk_elems[64];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags[2];
__m256 alphaReg;
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
//(Row0)
mat_b_col[0] = mat_b_rearr[0][0];
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
//i += cs_b_offset[6];
//ptr_b_dup += cs_b_offset[6];
i += 8;
ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += 8;
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += cs_b_offset[6];
i1 += cs_b_offset[6];
i3 += cs_l_offset[6];
i = 0;
i2 = 0;
for (k = 0; k < numCols_b; k += 8)
{
i = i1 + k;
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
i2++;
}
i = 0;
i2 = 0;
for (l = 0; l < j; l += 8) // move across m
{
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
// _mm256_permute2f128_ps()
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
i += cs_l_offset[6];
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
{
/////////////////// Partial Lower 8x8 block trsm of B
i4 = i2 + k;
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
i4 = k >> 3;
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
//end loop of cols
}
i2 += cs_b_offset[6];
}
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
i += cs_l;
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
k = 0;
for (i = 0; i < numCols_b; i+=8)
{
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
//(Row0): already done
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
}
}
///////////////////loop ends /////////////////////
}
#endif //OPT_CACHE_BLOCKING_L1
//////////////////////////// AutX=B ///////////////////////
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags[2];
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
//read diag elems of L 16x16 block
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
reciprocal_diags[1] = reciprocal_diags[0];
//pack first 8 diags together
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
#if 0
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
i += cs_b_offset[6];
ptr_b_dup += cs_b_offset[6];
//i += 8;
//ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += cs_l_offset[6];
//Read next 8x8 block of A to get diag elements
i3 += 8;
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
//pack 8 diags of A together
reciprocal_diags[0] = reciprocal_diags[1];
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += 8;
i1 += 8;
i = i1;
i2 = 0;
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
#endif
//i = 0;
ptr_l_dup = ptr_l;
i4 = i2;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
//{
/////////////////// Partial Lower 8x8 block trsm of B
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
#else
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
/* transpose steps end */
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i4 = k >> 3;
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
//}
//i2 += cs_b_offset[6];
i4 += 8;
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
//{
//i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i += cs_l;
#if GEMM_ACCUM_A
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
//}
i += cs_b_offset[6];
i2 += cs_b_offset[6];
}
} //numRows of A
///////////////////loop ends /////////////////////
}
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
{
float ones = 1.0;
int i, i1, i2, i3, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
__m256 mat_a_diag_inv[8];
__m256 reciprocal_diags[2];
__m256 alphaReg;
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
//read diag elems of L 16x16 block
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
reciprocal_diags[1] = reciprocal_diags[0];
//pack first 8 diags together
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
#if 0
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
#endif
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
i += cs_b_offset[6];
ptr_b_dup += cs_b_offset[6];
//i += 8;
//ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i3 = 0;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += cs_l_offset[6];
//Read next 8x8 block of A to get diag elements
i3 += 8;
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
//pack 8 diags of A together
reciprocal_diags[0] = reciprocal_diags[1];
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += 8;
i1 += 8;
i = i1;
i2 = 0;
//extract diag a00 from a
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
//extract diag a11 from a
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
//extract diag a22 from a
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
//extract diag a33 from a
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
//extract diag a44 from a
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
//extract diag a55 from a
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
//extract diag a66 from a
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
//extract diag a77 from a
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
#endif
//i = 0;
ptr_l_dup = ptr_l;
i4 = i2;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
//{
/////////////////// Partial Lower 8x8 block trsm of B
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
#else
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
/* transpose steps end */
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i4 = k >> 3;
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
//}
//i2 += cs_b_offset[6];
i4 += 8;
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
//{
//i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i += cs_l;
#if GEMM_ACCUM_A
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
//i += cs_l;
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
//}
i += cs_b_offset[6];
i2 += cs_b_offset[6];
}
} //numRows of A
///////////////////loop ends /////////////////////
}
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
{
//float ones = 1.0;
int i, i1, i2, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags[2];
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
#if 0
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
#endif
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
//(Row0)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
i += cs_b_offset[6];
ptr_b_dup += cs_b_offset[6];
//i += 8;
//ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += cs_l_offset[6];
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += 8;
i1 += 8;
i = i1;
i2 = 0;
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
#endif
//i = 0;
ptr_l_dup = ptr_l;
i4 = i2;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
//{
/////////////////// Partial Lower 8x8 block trsm of B
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
#else
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
/* transpose steps end */
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i4 = k >> 3;
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
//}
//i2 += cs_b_offset[6];
i4 += 8;
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
//{
//i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i += cs_l;
#if GEMM_ACCUM_A
//(Row0): already done
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
//i += cs_l;
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
//i += cs_l;
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
//i += cs_l;
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
//i += cs_l;
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
//i += cs_l;
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
//}
i += cs_b_offset[6];
i2 += cs_b_offset[6];
}
} //numRows of A
///////////////////loop ends /////////////////////
}
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
{
//float ones = 1.0;
int i, i1, i2, i4, j, k, l, r;
int cs_b_offset[7];
int cs_l_offset[7];
float *ptr_b_dup, *ptr_l_dup;
//57 number of ymm(256 bits) registers used
__m256 mat_b_col[8];
__m256 mat_b_rearr[8];
__m256 mat_a_blk_elems[8];
//__m256 mat_a_diag_inv[8];
//__m256 reciprocal_diags[2];
__m256 alphaReg;
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
//L matrix offsets
cs_l_offset[0] = (cs_l << 1);
cs_l_offset[1] = cs_l + cs_l_offset[0];
cs_l_offset[2] = (cs_l << 2);
cs_l_offset[3] = cs_l + cs_l_offset[2];
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
cs_l_offset[5] = cs_l + cs_l_offset[4];
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
cs_b_offset[0] = (cs_b << 1);
cs_b_offset[1] = cs_b + cs_b_offset[0];
cs_b_offset[2] = (cs_b << 2);
cs_b_offset[3] = cs_b + cs_b_offset[2];
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
cs_b_offset[5] = cs_b + cs_b_offset[4];
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
#if 0
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
//Broadcast A21 to A71 to registers
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
//Broadcast A32 to A72 to registers
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
//Broadcast A43 to A73 to registers
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
//Broadcast A54 to A74 to registers
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
//Broadcast A65 to A75 to registers
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
//Broadcast A76 to register
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
#endif
/***************** first set of 8 rows of B processing starts *****************/
ptr_b_dup = ptr_b;
i = 0;
for (j = 0; j < numCols_b; j += 8)
{
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
//read 8x8 block of B into registers
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
//(Row0)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
i += cs_b_offset[6];
ptr_b_dup += cs_b_offset[6];
//i += 8;
//ptr_b_dup += 8;
}
//c = 0;
/***************** first set of 8 cols of B processing done *****************/
ptr_b_dup = ptr_b;
i1 = 0;
//Start loop for cols of B to be processed in size of blk_width
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
{
ptr_l += cs_l_offset[6];
//ptr_b += j;
//ptr_b_dup += 8;
ptr_b_dup += 8;
i1 += 8;
i = i1;
i2 = 0;
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
{
#if GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
////unpackhigh////
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
/* transpose steps end */
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
#endif
//i = 0;
ptr_l_dup = ptr_l;
i4 = i2;
for (l = 0; l < j; l += 8) // move across m
{
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
//{
/////////////////// Partial Lower 8x8 block trsm of B
//Read current 8 cols of B columns from specified 8x8 current-block of B
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
#else
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
/* transpose steps end */
//Broadcast A8,0 to A15,0 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i4 = k >> 3;
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,2 to A15,2 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,3 to A15,3 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,4 to A15,4 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,5 to A15,5 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,6 to A15,6 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A8,7 to A15,7 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
ptr_l_dup++;
#if GEMM_ACCUM_A
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
#endif
//end loop of cols
//}
//i2 += cs_b_offset[6];
i4 += 8;
}
//trsm solve
k = 0;
//for (i2 = 0; i2 < numCols_b; i2 += 8)
//{
//i2 = i1 + r;
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
#if !GEMM_ACCUM_A
//Read 8 cols of B columns of Block-to-be-solved
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
#endif
//Broadcast A10 to A70 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
//i += cs_l;
#if GEMM_ACCUM_A
//(Row0): already done
#else
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
#endif
#if GEMM_ACCUM_A
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#else
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
#endif
//Broadcast A21 to A71 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
//i += cs_l;
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A32 to A72 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
//i += cs_l;
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A43 to A73 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
//i += cs_l;
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A54 to A74 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
//i += cs_l;
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A65 to A75 to registers
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
//i += cs_l;
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
//Broadcast A76 to register
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
//(Row7): FMA operations of b7 with elements of index (7, 0)
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
////////////////////////////////////////////////////////////////////////////////
/* transpose steps start */
////unpacklow////
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange low elements
#if REARRANGE_SHFL == 1
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
#else
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
#endif
//Merge rearranged low elements into complete rows
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
////unpackhigh////
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
//Rearrange high elements
#if REARRANGE_SHFL == 1
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
#else
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
#endif
//Merge rearranged high elements into complete rows
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
/* transpose steps end */
//Store the computed B columns
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
k++;
//}
i += cs_b_offset[6];
i2 += cs_b_offset[6];
}
} //numRows of A
///////////////////loop ends /////////////////////
}
#endif