mirror of
https://github.com/amd/blis.git
synced 2026-05-04 14:31:12 +00:00
Updated copyright information for kernels/zen/bli_trsm_small.c file Removed separate kernels for zen2 architecture Instead added threshold conditions in zen kernels both for ROME and NAPLES Change-Id: Ifd715731741d649b6ad16b123a86dbd6665d97e5
25102 lines
1.5 MiB
25102 lines
1.5 MiB
/*
|
|
|
|
BLIS
|
|
An object-based framework for developing high-performance BLAS-like
|
|
libraries.
|
|
|
|
Copyright (C) 2018-2019, Advanced Micro Devices, Inc.
|
|
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are
|
|
met:
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
- Neither the name of The University of Texas at Austin nor the names
|
|
of its contributors may be used to endorse or promote products
|
|
derived from this software without specific prior written permission.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*/
|
|
|
|
#include "blis.h"
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
|
|
#include "immintrin.h"
|
|
#define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm
|
|
#define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1'
|
|
#define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together.
|
|
#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle
|
|
#define BLI_AlXB_M_SP 16
|
|
#define BLI_XAltB_N_SP 128
|
|
#define BLI_AutXB_M_SP 64
|
|
#define BLI_AutXB_N_SP 128
|
|
|
|
// XA = B; A is lower-traingular; No transpose; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_XAlB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is lower triabgular; No transpose; double precision; unit-diagonal
|
|
static err_t bli_dtrsm_small_XAlB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal
|
|
static err_t bli_dtrsm_small_XAltB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal
|
|
static err_t bli_dtrsm_small_XAltB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
// XA = B; A is upper triangular; No transpose; double presicion; non-unit diagonal
|
|
static err_t bli_dtrsm_small_XAuB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is upper triangular; No transpose; double precision; unit-diagonal
|
|
static err_t bli_dtrsm_small_XAuB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_XAutB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is upper-triangular; A is transposed; double precision; unit diagonal
|
|
static err_t bli_dtrsm_small_XAutB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//AX = B; A is lower triangular; No transpose; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_AlXB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//AX = B; A is lower triangular; No transpose; double precision; unit diagonal
|
|
static err_t bli_dtrsm_small_AlXB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
|
|
|
|
static void (*fp_blis_strsm_microkernel)( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
static void blis_strsm_microkernel( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
static void blis_strsm_microkernel_alpha( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal
|
|
);
|
|
static void blis_strsm_microkernel_unitDiag( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal
|
|
);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal);
|
|
|
|
|
|
static void blis_dtrsm_microkernel( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
|
|
static void blis_dtrsm_microkernel_alpha( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal
|
|
);
|
|
|
|
static void blis_dtrsm_microkernel_unitDiag( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
|
|
static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal
|
|
);
|
|
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal);
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alpha);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alpha);
|
|
|
|
//AX = B; A is lower triangular; No transpose; single precision
|
|
static err_t bli_strsm_small_AlXB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
//A.'X = B; A is upper triangular; A has to be transposed; single precision
|
|
static err_t bli_strsm_small_AutXB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA.' = B; A is lower triangular; A has to be transposed; single precision
|
|
static err_t bli_strsm_small_XAltB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//A.'X = B; A is upper triangular; A has to be transposed; double precision
|
|
static err_t bli_dtrsm_small_AutXB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
/*
|
|
* The bli_trsm_small implements unpacked version of TRSM
|
|
* Currently only column-major is supported, A & B are column-major
|
|
* Input: A: MxM (triangular matrix)
|
|
* B: MxN matrix
|
|
* Output: X: MxN matrix such that AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B
|
|
* Here the output X is stored in B
|
|
* The custom-kernel will be called only when M*(M+N)* sizeof(Matrix Elements) < L3 cache
|
|
*/
|
|
err_t bli_trsm_small
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
#ifdef BLIS_ENABLE_MULTITHREADING
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#endif
|
|
|
|
dim_t m = bli_obj_length(b);
|
|
dim_t n = bli_obj_width(b);
|
|
|
|
if(!(m && n))
|
|
return BLIS_SUCCESS;
|
|
|
|
|
|
// If alpha is zero, B matrix will become zero after scaling & hence solution is also zero matrix
|
|
if (bli_obj_equals(alpha, &BLIS_ZERO))
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha
|
|
}
|
|
// We have to call matrix scaling if alpha != 1.0
|
|
|
|
// if row major format return. Check this again.
|
|
if ((bli_obj_row_stride(a) != 1) ||
|
|
(bli_obj_row_stride(b) != 1))
|
|
{
|
|
return BLIS_INVALID_ROW_STRIDE;
|
|
}
|
|
|
|
num_t dt = ((*b).info & (0x7 << 0));
|
|
|
|
// only float and double datatypes are supported as of now.
|
|
if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT)
|
|
{
|
|
return BLIS_EXPECTED_REAL_DATATYPE;
|
|
}
|
|
|
|
// A is expected to be triangular in trsm
|
|
if (!bli_obj_is_upper_or_lower (a))
|
|
{
|
|
return BLIS_EXPECTED_TRIANGULAR_OBJECT;
|
|
}
|
|
|
|
// can use other control structs - even can use array of function pointers,
|
|
// indexed by a number with bits formed by f('side', 'uplo', 'transa', dt).
|
|
// In the below implementation, based on the number of finally implemented
|
|
// cases, can move the checks with more cases higher up.
|
|
|
|
if(side == BLIS_LEFT)
|
|
{
|
|
if(bli_obj_has_trans(a))
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
//return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
else
|
|
{
|
|
//return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_trans(a))
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
//return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
}
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
};
|
|
|
|
/* TRSM scalar code for the case AX = alpha * B
|
|
* A is lower-triangular, non-unit-diagonal, no transpose
|
|
* Dimensions: A: mxm X: mxn B:mxn
|
|
*/
|
|
|
|
static err_t dtrsm_small_AlXB (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for (k = 0; k < M; k++)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
B[k + j*ldb] *= lkk_inv;
|
|
for (i = k+1; i < M; i++)
|
|
{
|
|
B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
|
|
}
|
|
}
|
|
}// k -loop
|
|
return BLIS_SUCCESS;
|
|
}// end of function
|
|
|
|
/* TRSM scalar code for the case AX = alpha * B
|
|
* A is lower-triangular, unit-diagonal, no transpose
|
|
* Dimensions: A: mxm X: mxn B:mxn
|
|
*/
|
|
|
|
static err_t dtrsm_small_AlXB_unitDiag (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for (k = 0; k < M; k++)
|
|
{
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
for (i = k+1; i < M; i++)
|
|
{
|
|
B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}// end of function
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is upper-triangular, non-unit-diagonal no transpose
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAuB (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
B[i+k*ldb] *= lkk_inv;
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda];
|
|
}
|
|
}
|
|
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is lower-triangular, non-unit triangular, no transpose
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
|
|
static err_t dtrsm_small_XAlB (
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
for(i = 0; i < M; i++)
|
|
for(j = 0; j < N; j++)
|
|
B[i+j*ldb] *= alpha;
|
|
|
|
for(k = N-1; k+1 > 0; k--)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for(i = M-1; i+1 > 0; i--)
|
|
{
|
|
B[i+k*ldb] *= lkk_inv;
|
|
for(j = k-1; j+1 > 0; j--)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is lower-triangular, unit-diagonal, no transpose
|
|
*Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAlB_unitDiag(
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(i = 0 ; i < M; i++)
|
|
for(j = 0; j < N; j++)
|
|
B[i+j*ldb] *= alpha;
|
|
|
|
for(k = N-1; k+1 > 0; k--)
|
|
{
|
|
for(i = M-1; i+1 > 0; i--)
|
|
{
|
|
for(j = k-1; j+1 > 0; j--)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
*A is upper-triangular, non-unit-diagonal, A is transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAutB (
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(i = 0; i < M; i++)
|
|
for(j = 0; j < N; j++)
|
|
B[i+j*ldb] *=alpha;
|
|
|
|
for(k = N-1; k+1 > 0; k--)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for(i = M-1; i+1 > 0; i--)
|
|
{
|
|
B[i+k*ldb] *= lkk_inv;
|
|
for(j = k-1; j+1 > 0; j--)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is upper-triangular, unit-diagonal, A has to be transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAutB_unitDiag(
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(i = 0; i< M; i++)
|
|
for(j = 0; j< N; j++)
|
|
B[i+j*ldb] *= alpha;
|
|
|
|
for(i = M-1; i+1 > 0; i--)
|
|
{
|
|
for(j = N-1; j+1 > 0; j--)
|
|
{
|
|
for(k = j-1; k+1 > 0; k--)
|
|
{
|
|
B[i+k*ldb] -= B[i+j*ldb] * A[k+j*lda];
|
|
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is lower-triangular, non-unit-diagonal, A has to be transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAltB (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
B[i+k*ldb] *= lkk_inv;
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for XA = alpha * B
|
|
* A is lower-triangular, unit-diagonal, A has to be transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAltB_unitDiag(
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is upper-triangular, unit-diagonal, no transpose
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAuB_unitDiag (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM for the case AX = alpha * B, Double precision
|
|
* A is lower-triangular, no-transpose, non-unit diagonal
|
|
* dimensions A: mxm X: mxn B: mxn
|
|
|
|
b01--->
|
|
* *****************
|
|
** * * * * *
|
|
* * * * * * *
|
|
* * *b01* * * *
|
|
* * * * * * *
|
|
a10 ****** b11 *****************
|
|
| * * * | * * * * *
|
|
| * * * | * * * * *
|
|
| *a10*a11* | *b11* * * *
|
|
v * * * v * * * * *
|
|
*********** *****************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
**************** *****************
|
|
a11--->
|
|
*/
|
|
|
|
static err_t bli_dtrsm_small_AlXB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
|
|
dim_t D_MR = 4; //size of block along 'M' dimpension
|
|
dim_t D_NR = 8; //size of block along 'N' dimension
|
|
|
|
dim_t m = bli_obj_length(b); // number of rows of matrix B
|
|
dim_t n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t m_remainder = m % D_MR; //number of remainder rows
|
|
dim_t n_remainder = n % D_NR; //number of remainder columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); // column stride of A
|
|
dim_t cs_b = bli_obj_col_stride(b); // column stride of B
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //number of times GEMM to be performed
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
|
|
double *ptr_b01_dup;
|
|
|
|
double ones = 1.0;
|
|
|
|
//scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
|
|
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code begins///
|
|
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM
|
|
}
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7]
|
|
|
|
///implement TRSM///
|
|
|
|
///transpose of B11//
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//broadcast diagonal elements of A11
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
|
|
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//extract a00
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
|
|
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a22
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A110[][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
|
|
|
|
//perform mul operation
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /= A11[2][2]
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a33
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11);//1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6]
|
|
|
|
//perform mul operation
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3]
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3]
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3]
|
|
}
|
|
|
|
if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4)
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code Begins///
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] )
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm16 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //move to next row of B01
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm8 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm8 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm8 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm8 += (B01[3][0]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm8 += (B01[3][1]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm8 += (B01[3][2]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm8 += (B01[3][3]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7]
|
|
|
|
///implement TRSM///
|
|
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//broadcast diagonal elements of A11
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
|
|
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a00
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
|
|
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= B11[0-3][0]*A11[3][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= B11[0-3][4]*A11[3][4]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a22
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
|
|
|
|
//perform mul operation
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2]
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a33
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[0-3][3] -= A11[3][2]*B11[0-3][2]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[0-3][7] -= A11[3][2]*B11[0-3][6]
|
|
|
|
//perform mul operation
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3]
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3]
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b * 7)); //load B11[0-3][7]
|
|
//determine correct values to store
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30);
|
|
ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30);
|
|
ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30);
|
|
ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30);
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store(B11[0-3][7])
|
|
|
|
}
|
|
}
|
|
|
|
if((n & 4)) //implementation for remainder columns(when 'N' is a multiple of 4)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4)
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B01[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B01[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B01[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
/*
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg);
|
|
*/
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
/*
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg);
|
|
*/
|
|
//extract a00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
|
|
|
|
//extract diag a11 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3]
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
|
|
|
|
|
|
//extract diag a22 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
|
|
|
|
//extract diag a33 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3])
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
/*
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg);
|
|
*/
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
/*
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg);
|
|
*/
|
|
//extract a00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
|
|
|
|
//extract diag a11 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]* B11[0][0-3]
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
|
|
|
|
|
|
//extract diag a22 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]* B11[1][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
|
|
|
|
//extract diag a33 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]* B11[2][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//load 4x4 block from b11
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//determine correct values to store
|
|
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
|
|
|
|
}
|
|
|
|
n_remainder -= 4;
|
|
j += 4;
|
|
|
|
}
|
|
|
|
if(n_remainder) //implementation for remaining columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const*)&ones);
|
|
}
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
/*
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg);
|
|
*/
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
/*
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg);
|
|
*/
|
|
//extract a00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
|
|
|
|
//extract diag a11 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3]
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
|
|
|
|
|
|
//extract diag a22 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
|
|
|
|
//extract diag a33 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operations to be performed
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
|
|
//load 4x4 block from b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[0][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[1][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[2][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[3][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] * alpha -= ymm5
|
|
ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] * alpha -= ymm6
|
|
ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08);
|
|
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm11, ymm3, 0x30);
|
|
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E);
|
|
}
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
///scalar code for trsm without alpha///
|
|
dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM for the case AX = alpha * B, Double precision
|
|
* A is lower-triangular, no-transpose, unit diagonal
|
|
* dimensions A: mxm X: mxn B: mxn
|
|
|
|
b01--->
|
|
* *****************
|
|
** * * * * *
|
|
* * * * * * *
|
|
* * *b01* * * *
|
|
* * * * * * *
|
|
a10 ****** b11 *****************
|
|
| * * * | * * * * *
|
|
| * * * | * * * * *
|
|
| *a10*a11* | *b11* * * *
|
|
v * * * v * * * * *
|
|
*********** *****************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
**************** *****************
|
|
a11--->
|
|
*/
|
|
|
|
static err_t bli_dtrsm_small_AlXB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
|
|
dim_t D_MR = 4; //size of block along 'M' dimension
|
|
dim_t D_NR = 8; //size of block along 'N' dimension
|
|
|
|
dim_t m = bli_obj_length(b); // number of rows of matrix B
|
|
dim_t n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t m_remainder = m % D_MR; //number of remainder rows
|
|
dim_t n_remainder = n % D_NR; //number of remainder columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); // column stride of A
|
|
dim_t cs_b = bli_obj_col_stride(b); // column stride of B
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //number of times GEMM to be performed
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
|
|
double *ptr_b01_dup;
|
|
|
|
double ones = 1.0;
|
|
|
|
//scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
|
|
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code begins///
|
|
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM
|
|
}
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7]
|
|
|
|
///implement TRSM///
|
|
|
|
///transpose of B11//
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
//broadcast diagonal elements of A11
// NOTE(review): in this unit-diagonal kernel the four diagonal broadcasts
// below (ymm1..ymm4) look like dead code -- each register is overwritten
// before use (ymm2..ymm4 a few lines down, ymm1 in the transpose section).
// They also index with cs_b where cs_a seems intended for matrix A; confirm
// and consider removing.
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw2): FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6]
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3]
|
|
}
|
|
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4)
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code Begins///
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] )
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //move to next row of B01
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7]
|
|
|
|
///implement TRSM///
|
|
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
//broadcast diagonal elements of A11
//NOTE(review): the four diagonal broadcasts just below load ymm1-ymm4 but those
//registers are overwritten (ymm2-ymm4 re-broadcast, ymm1 reused later) before the
//diagonal values are ever consumed -- they appear to be dead loads. They also
//index A11 with cs_b where cs_a would be expected for the A matrix; harmless while
//dead, but confirm intent before relying on them.
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_b +1)); //A11[1][1]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_b*2 + 2)); //A11[2][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_b*3 + 3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= B11[0-3][0]*A11[3][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= B11[0-3][4]*A11[3][4]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw2): FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[0-3][3] -= A11[3][2]*B11[0-3][2]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[0-3][7] -= A11[3][2]*B11[0-3][6]
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b * 7)); //load B11[0-3][7]
|
|
//determine correct values to store
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30);
|
|
ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30);
|
|
ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30);
|
|
ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30);
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store(B11[0-3][7])
|
|
|
|
}
|
|
}
|
|
|
|
if((n & 4)) //implementation for remainder columns(when 'N' is a multiple of 4)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4)
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B01[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B01[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B01[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3])
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++) //looop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]* B11[0][0-3]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]* B11[1][0-3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]* B11[2][0-3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//load 4x4 block from b11
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//determine correct values to store
|
|
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
|
|
|
|
}
|
|
|
|
n_remainder -= 4;
|
|
j += 4;
|
|
|
|
}
|
|
|
|
if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const*)&ones);
|
|
}
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operations to be performed
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
|
|
//load 4x4 block from b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[0][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[1][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[2][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B10[3][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); //B11[0-3][1] * alpha -= ymm5
|
|
ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm6); //B11[0-3][2] * alpha -= ymm6
|
|
ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm7); //B11[0-3][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x08);
|
|
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm11, ymm3, 0x30);
|
|
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm11, ymm3, 0x0E);
|
|
}
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
///scalar code for trsm without alpha///
|
|
dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is upper triangular, non-unit diagonal, no transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b10*b11* * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAuB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch reginsters
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C);           //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1);      //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2);    //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3);    //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4);    //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5);    //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6);    //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6);          //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C);           //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j;                 //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
            //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
if(m_remainder) ///omplementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
          ///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
            //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
            //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
            //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
            //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
ymm7 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);
|
|
///implement TRSM///
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm7, ymm3, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm7,ymm3,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm7,ymm3,0x0E);
|
|
}
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAuB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is upper triangular, unit-diagonal, no transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b01*b11* * * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
|
|
static err_t bli_dtrsm_small_XAuB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
    //ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
				//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
				//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
            ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
if(m_remainder) ///omplementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
            ///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
ymm7 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);
|
|
///implement TRSM///
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm7, ymm3, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm7,ymm3,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm7,ymm3,0x0E);
|
|
}
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAuB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower triangular, non-unit diagonal, transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b01*b11* * * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAltB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
    //ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[0-3][1] * alpha -= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][2] * alpha -= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[0-3][3] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[4-7][0] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][1] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[4-7][2] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[0-3][1] * alpha -= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][2] * alpha -= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[0-3][3] * alpha -= ymm3
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[4-7][0] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][1] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[4-7][2] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
if(m_remainder) ///omplementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
///GEMM implementation stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
ymm7 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);
|
|
///implement TRSM///
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm7, ymm3, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm7,ymm3,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm7,ymm3,0x0E);
|
|
}
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower-triangular, unit-diagonal, and transposed
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b01*b11* * * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAltB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[0-3][1] * alpha -= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][2] * alpha -= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[0-3][3] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[4-7][0] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][1] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[4-7][2] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm15 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[0-3][1] * alpha -= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][2] * alpha -= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[0-3][3] * alpha -= ymm3
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[4-7][0] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][1] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[4-7][2] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
if(m_remainder) ///omplementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
}
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
ymm7 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);
|
|
///implement TRSM///
|
|
if(m_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm7, ymm3, 0x08);
|
|
}
|
|
if(m_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm7,ymm3,0x30);
|
|
}
|
|
if(m_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm7,ymm3,0x0E);
|
|
}
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAltB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower triangular, non-unit diagonal, no transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAlB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
    //ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
            ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1);      //B11[0-3][1] * alpha -= ymm1
|
|
            ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2);    //B11[0-3][2] * alpha -= ymm2
|
|
            ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3);    //B11[0-3][3] * alpha -= ymm3
|
|
|
|
            ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4);    //B11[4-7][0] * alpha -= ymm4
|
|
            ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5);    //B11[4-7][1] * alpha -= ymm5
|
|
            ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6);    //B11[4-7][2] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a33
|
|
            ymm0 = _mm256_permute_pd(ymm7, 0x0C);           //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1): FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a33
|
|
            ymm0 = _mm256_permute_pd(ymm7, 0x0C);           //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1): FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
}
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
            ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
// if(i < 0) i = 0;
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower triangular, unit-diagonal, no transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAlB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
    //ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
//(Row 1): FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
//(Row 1): FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //store(B11[x][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
}
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm16); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm16); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm16); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm16); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
if(m_remainder)
|
|
{
|
|
dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*implements TRSM for the case XA = alpha * B
 *NOTE(review): the function below is bli_dtrsm_small_XAutB, which by BLIS
 *naming convention handles A upper triangular with transpose; the original
 *text here ("lower triangular, non-unit diagonal, no transpose") appears to
 *be copied from a sibling variant -- confirm against the kernel body.
 *dimensions: X:mxn A:nxn B: mxn
 */
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAutB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch reginsters
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
//extract a33
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm7);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm7);
|
|
|
|
//extract a22
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm7);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm7);
|
|
|
|
//extract a11
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm7);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm7);
|
|
|
|
//extract A00
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm7);
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm7);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
//extract a33
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm7);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm7);
|
|
|
|
//extract a22
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm7);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm7);
|
|
|
|
//extract a11
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm7);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm7);
|
|
|
|
//extract A00
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm7);
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm7);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
}
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is upper triangular, unit-diagonal, transposed (the j-loop walks columns right-to-left and reads A's upper triangle)
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAutB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
dim_t m_remainder = m % D_MR; //number of corner rows
|
|
dim_t n_remainder = n % D_NR; //number of corner columns
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ROME)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm8 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][2]
|
|
ymm13 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][2]
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones); //B11[0-3][3]
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones); //B11[4-7][3]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
}
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is a multiple of 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
if(n_remainder == 3)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm1 = _mm256_broadcast_sd((double const *)&ones); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm2 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
}
|
|
|
|
//multiply by alpha
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] *= alpha
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] *=alpha
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] *= alpha
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] *= alpha
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
|
|
ymm0 = _mm256_sub_pd(ymm0, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_sub_pd(ymm1, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_sub_pd(ymm2, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_sub_pd(ymm3, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
if(m_remainder)
|
|
{
|
|
dtrsm_small_XAutB_unitDiag(a->buffer, b->buffer,AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*
|
|
* AX = Alpha*B, Single precision, A: lower triangular
|
|
 * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is a multiple of 8
|
|
*/
|
|
|
|
static err_t bli_strsm_small_AlXB (
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
obj_t alpha, beta; // gemm parameters
|
|
obj_t Ga, Gb, Gc; // for GEMM
|
|
int m = bli_obj_length(b); // number of rows of matrix B
|
|
int n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of A
|
|
int ldb = bli_obj_col_stride(b); // column stride of B
|
|
|
|
int rsa = bli_obj_row_stride(a); // row stride of A
|
|
int rsb = bli_obj_row_stride(b); // row stride of B
|
|
|
|
int i = 0;
|
|
int j;
|
|
int blk_size = 8;
|
|
int isUnitDiag = bli_obj_has_unit_diag(a);
|
|
|
|
float alphaVal;
|
|
float *L = a->buffer;
|
|
float *B = b->buffer;
|
|
|
|
if (m != BLI_AlXB_M_SP || (n&7) != 0)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));
|
|
|
|
/* Small _GEMM preparation code */
|
|
bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha );
|
|
bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta );
|
|
|
|
/* B = B - A*B */
|
|
bli_setsc( -(1.0), 0.0, &alpha );
|
|
bli_setsc( (1.0), 0.0, &beta );
|
|
|
|
|
|
bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga);
|
|
bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb);
|
|
bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc);
|
|
|
|
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga );
|
|
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb );
|
|
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc );
|
|
|
|
//first block of trsm
|
|
Gb.buffer = (void*)(B + i);
|
|
|
|
//trsm of first 8xn block
|
|
if (alphaVal != 1)
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel;
|
|
}
|
|
else
|
|
{
|
|
blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag;
|
|
}
|
|
bli_setsc( alphaVal, 0.0, &beta );
|
|
}
|
|
else
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel;
|
|
}
|
|
else
|
|
{
|
|
blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag;
|
|
}
|
|
}
|
|
|
|
//gemm update
|
|
for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT
|
|
{
|
|
Ga.buffer = (void*)(L + j + i*lda);
|
|
Gc.buffer = (void*)(B + j);
|
|
|
|
bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb
|
|
}
|
|
|
|
//trsm of remaining blocks
|
|
for (i = blk_size; i < m; i += blk_size)
|
|
{
|
|
Gb.buffer = (void*)(B + i);
|
|
|
|
fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
|
|
for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT
|
|
{
|
|
Ga.buffer = (void*)(L + j + i*lda);
|
|
Gc.buffer = (void*)(B + j);
|
|
|
|
bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb
|
|
}
|
|
|
|
} // End of for loop - i
|
|
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
* XA' = Alpha*B, Single precision, A: lower triangular
|
|
* This kernel implementation supports matrices A and B such that
|
|
* m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP
|
|
*/
|
|
static err_t bli_strsm_small_XAltB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
int m = bli_obj_length(a); // number of rows of matrix B
|
|
int n = bli_obj_length(b); // number of columns of matrix B
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of A
|
|
int ldb = bli_obj_col_stride(b); // column stride of B
|
|
|
|
int rsa = bli_obj_row_stride(a); // row stride of A
|
|
int rsb = bli_obj_row_stride(b); // row stride of B
|
|
|
|
int i = 0;
|
|
int isUnitDiag = bli_obj_has_unit_diag(a);
|
|
|
|
float alphaVal;
|
|
float *L = a->buffer;
|
|
float *B = b->buffer;
|
|
|
|
if ((m&7) != 0 || (n&7) != 0)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));
|
|
|
|
if (alphaVal != 1)
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
else
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
else
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*
|
|
* A'X = Alpha*B, Single precision, A: upper triangular
|
|
* This kernel implementation supports matrices A and B such that
|
|
* m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP
|
|
*/
|
|
static err_t bli_strsm_small_AutXB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken)
|
|
int n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of A
|
|
int ldb = bli_obj_col_stride(b); // column stride of B
|
|
|
|
int rsa = bli_obj_row_stride(a); // row stride of A
|
|
int rsb = bli_obj_row_stride(b); // row stride of B
|
|
|
|
int i = 0;
|
|
int isUnitDiag = bli_obj_has_unit_diag(a);
|
|
|
|
float alphaVal;
|
|
float *L = a->buffer;
|
|
float *B = b->buffer;
|
|
|
|
if ((m&7) != 0 || (n&7) != 0)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));
|
|
|
|
if (alphaVal != 1)
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
else
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
else
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
///////////////////////////// AX=B ///////////////////////////////
|
|
static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
|
|
{
|
|
float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags;
|
|
__m256 alphaReg;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4rth col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//8th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagonal elements
|
|
reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
|
|
{
|
|
//float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags;
|
|
__m256 alphaReg;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
//reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//8th col
|
|
//ptr_l += cs_l;
|
|
//mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
//mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
//mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagonal elements
|
|
//reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
//reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4rth col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//8th col
|
|
//ptr_l += cs_l;
|
|
//mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
//mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
//mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagnol elements
|
|
//reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
        //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
        //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
    //4th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
    //8th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
    //reciprocal of diagonal elements
|
|
reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
        //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
static void blis_dtrsm_microkernel_alpha(double *ptr_l,
|
|
double *ptr_b,
|
|
int m,
|
|
int n,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal
|
|
)
|
|
{
|
|
int j;
|
|
int n_remainder = n%4;
|
|
int cs_b_offset[2];
|
|
double *ptr_b_dup;
|
|
double ones = 1.0;
|
|
__m256d mat_b_col[4];
|
|
__m256d mat_b_rearr[4];
|
|
__m256d mat_a_cols[4];
|
|
__m256d mat_a_cols_rearr[10];
|
|
__m256d mat_a_diag_inv[4];
|
|
__m256d reciprocal_diags;
|
|
__m256d alphaReg;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
|
|
reciprocal_diags = _mm256_broadcast_sd((double const *)&ones);
|
|
alphaReg = _mm256_broadcast_sd((double const *)&alphaVal);
|
|
|
|
//if(m % 4 == 0)
|
|
//{
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_sd((double const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_sd((double const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_sd((double const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_sd((double const *)(ptr_l+3));
|
|
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_sd((double const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_sd((double const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_sd((double const *)(ptr_l + 3));
|
|
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_sd((double const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_sd((double const *)(ptr_l + 3));
|
|
|
|
//4th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_sd((double const *)(ptr_l + 3));
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
mat_a_diag_inv[0] = _mm256_unpacklo_pd(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
mat_a_diag_inv[1] = _mm256_unpacklo_pd(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
|
|
mat_a_diag_inv[0] = _mm256_blend_pd(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x0C);
|
|
reciprocal_diags = _mm256_div_pd(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
for(j = 0;(j+3) < n; j += 4)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
/*Shuffle to rearrange/transpose 8x4 block of B into contiguous row-wise registers*/
|
|
|
|
//read first set of 4x4 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[1]));
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[1] = _mm256_unpacklo_pd(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_pd(mat_b_col[2], mat_b_col[3]);
|
|
|
|
//rearrange low elements
|
|
mat_b_rearr[0] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x20);
|
|
mat_b_rearr[2] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_pd(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_pd(mat_b_col[2], mat_b_col[3]);
|
|
|
|
//rearrange high elements
|
|
mat_b_rearr[1] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x20);
|
|
mat_b_rearr[3] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x31);
|
|
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg);
|
|
//extract a00
|
|
mat_a_diag_inv[0] = _mm256_permute_pd(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_pd(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_pd(reciprocal_diags, 0x03);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_pd(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_pd(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_pd(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_pd(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x11);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_pd(reciprocal_diags, 0x0C);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_pd(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x11);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[1] = _mm256_unpacklo_pd(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[3] = _mm256_unpacklo_pd(mat_b_rearr[2], mat_b_rearr[3]);
|
|
|
|
//rearrange low elements
|
|
mat_a_cols[0] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x20);
|
|
mat_a_cols[2] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_pd(mat_b_rearr[0], mat_b_rearr[1]);
|
|
|
|
mat_b_rearr[1] = _mm256_unpackhi_pd(mat_b_rearr[2], mat_b_rearr[3]);
|
|
|
|
//rearrange high elements
|
|
mat_a_cols[1] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x20);
|
|
mat_a_cols[3] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b+cs_b_offset[1]);
|
|
_mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_pd((double *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_pd((double *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_pd((double *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
|
|
}
|
|
ptr_b_dup = ptr_b;
|
|
if(n_remainder == 3)
|
|
{
|
|
|
|
//read first set of 4x4 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_pd((double const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
//read first set of 4x4 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_pd((double const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_broadcast_sd((double const *)&ones);
|
|
mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
//read first set of 4x4 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_pd((double const *)ptr_b);
|
|
mat_b_col[1] = _mm256_broadcast_sd((double const *)&ones);
|
|
mat_b_col[2] = _mm256_broadcast_sd((double const *)&ones);
|
|
mat_b_col[3] = _mm256_broadcast_sd((double const *)&ones);
|
|
}
|
|
/*Shuffle to rearrange/transpose 8x4 block of B into contiguous row-wise registers*/
|
|
////unpacklow////
|
|
mat_b_rearr[1] = _mm256_unpacklo_pd(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_pd(mat_b_col[2], mat_b_col[3]);
|
|
//rearrange low elements
|
|
mat_b_rearr[0] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x20);
|
|
mat_b_rearr[2] = _mm256_permute2f128_pd(mat_b_rearr[1],mat_b_rearr[3],0x31);
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], alphaReg);
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_pd(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_pd(mat_b_col[2], mat_b_col[3]);
|
|
//rearrange high elements
|
|
mat_b_rearr[1] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x20);
|
|
mat_b_rearr[3] = _mm256_permute2f128_pd(mat_b_col[0],mat_b_col[1],0x31);
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], alphaReg);
|
|
//extract a00
|
|
mat_a_diag_inv[0] = _mm256_permute_pd(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_pd(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_pd(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_pd(reciprocal_diags, 0x03);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_pd(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_pd(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_pd(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_pd(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_pd(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x11);
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_pd(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_pd(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_pd(reciprocal_diags, 0x0C);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_pd(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x11);
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_pd(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_pd(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[1] = _mm256_unpacklo_pd(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[3] = _mm256_unpacklo_pd(mat_b_rearr[2], mat_b_rearr[3]);
|
|
//rearrange low elements
|
|
mat_a_cols[0] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x20);
|
|
mat_a_cols[2] = _mm256_permute2f128_pd(mat_a_cols[1],mat_a_cols[3],0x31);
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_pd(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_pd(mat_b_rearr[2], mat_b_rearr[3]);
|
|
//rearrange high elements
|
|
mat_a_cols[1] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x20);
|
|
mat_a_cols[3] = _mm256_permute2f128_pd(mat_b_rearr[0],mat_b_rearr[1],0x31);
|
|
//Store the computed B columns
|
|
if(n_remainder == 3)
|
|
{
|
|
_mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_pd((double *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_pd((double *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
}
|
|
if(n_remainder == 2)
|
|
{
|
|
_mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_pd((double *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
_mm256_storeu_pd((double *)ptr_b_dup, mat_a_cols[0]);
|
|
}
|
|
|
|
//}
|
|
|
|
}
|
|
|
|
#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels
|
|
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0)
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0)
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b8 with elements of indices from (8, 0) up till (15, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b9 with elements of indices from (8, 1) up till (15, 1)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b10 with elements of indices from (8, 2) up till (15, 2)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b11 with elements of indices from (8, 3) up till (15, 3)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b12 with elements of indices from (8, 4) up till (15, 4)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b13 with elements of indices from (8, 5) up till (15, 5)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b14 with elements of indices from (8, 6) up till (15, 6)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b15 with elements of indices from (8, 7) up till (15, 7)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
#else //rel 1.0 intrinsic kernels (NOT OPT_CACHE_BLOCKING_L1)
|
|
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b7 with elements of indices from (0, 0) up to (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
|
|
mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
|
|
mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
|
|
mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
|
|
mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
|
|
mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
|
|
mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
|
|
mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
|
|
mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
|
|
mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
|
|
mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
|
|
mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
|
|
mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
|
|
mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
|
|
mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
|
|
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
//__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0)
|
|
mat_b_col[0] = mat_b_rearr[0][0];
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
                //Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
                    //(Row8): FMA operations of b0 with all 8 partial sums of the current block
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
                    //(Row11): FMA operations of b3 with all 8 partial sums of the current block
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
                    //(Row12): FMA operations of b4 with all 8 partial sums of the current block
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
                    //(Row13): FMA operations of b5 with all 8 partial sums of the current block
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
                    //(Row14): FMA operations of b6 with all 8 partial sums of the current block
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
                    //(Row15): FMA operations of b7 with all 8 partial sums of the current block
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): already done
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
//__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
|
|
mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
|
|
mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
|
|
mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
|
|
mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
|
|
mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
|
|
mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
|
|
mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
|
|
|
|
//(Row0)
|
|
mat_b_col[0] = mat_b_rearr[0][0];
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
|
|
mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
|
|
mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
|
|
mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
|
|
mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
|
|
mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
|
|
mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
|
|
mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
|
|
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): already done
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
#endif //OPT_CACHE_BLOCKING_L1
|
|
|
|
//////////////////////////// AutX=B ///////////////////////
|
|
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += 8;
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b0 with broadcast A elements (rows 8 to 15, column 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with broadcast A elements (rows 8 to 15, column 1)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with broadcast A elements (rows 8 to 15, column 2)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b3 with broadcast A elements (rows 8 to 15, column 3)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b4 with broadcast A elements (rows 8 to 15, column 4)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b5 with broadcast A elements (rows 8 to 15, column 5)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b6 with broadcast A elements (rows 8 to 15, column 6)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b7 with broadcast A elements (rows 8 to 15, column 7)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
            //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
            //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
    //reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
        //Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
        //Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += 8;
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of transposed B row 0 with broadcast A elements (8,0) up to (15,0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
|
|
//(Row0)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up till (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up till (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA of the broadcast A elements with solved row mat_b_col[0], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA of the broadcast A elements with solved row mat_b_col[1], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA of the broadcast A elements with solved row mat_b_col[2], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA of the broadcast A elements with solved row mat_b_col[3], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA of the broadcast A elements with solved row mat_b_col[4], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA of the broadcast A elements with solved row mat_b_col[5], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA of the broadcast A elements with solved row mat_b_col[6], accumulated into mat_b_rearr[0..7]
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//FMA operations folding b-column 7 against the broadcast A elements A8,7 up till A15,7
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c + (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c + (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c + (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c + (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c + (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c + (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c + (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c + (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c + (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c + (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c + (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c + (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c + (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c + (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c + (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c + (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
#endif
|