mirror of
https://github.com/amd/blis.git
synced 2026-04-29 12:01:12 +00:00
Details: - Added an err_t* parameter to memory allocation functions including bli_malloc_intl(), bli_calloc_intl(), bli_malloc_user(), bli_fmalloc_align(), and bli_fmalloc_noalign(). Since these functions already use the return value to return the allocated memory address, they can't communicate errors to the caller through the return value. This commit does not employ any error checking within these functions or their callers, but this sets up BLIS for a more comprehensive commit that moves in that direction. - Moved the typedefs for malloc_ft and free_ft from bli_malloc.h to bli_type_defs.h. This was done so that what remains of bli_malloc.h can be included after the definition of the err_t enum. (This ordering was needed because bli_malloc.h now contains function prototypes that use err_t.) - Defined bli_is_success() and bli_is_failure() static functions in bli_param_macro_defs.h. These functions provide easy checks for error codes and will be used more heavily in future commits. - Unfortunately, the additional err_t* argument discussed above breaks the API for bli_malloc_user(), which is an exported symbol in the shared library. However, it's quite possible that the only application that calls bli_malloc_user()--indeed, the reason it was marked for symbol exporting to begin with--is the BLIS testsuite. And if that's the case, this breakage won't affect anyone. Nonetheless, the "major" part of the so_version file has been updated accordingly to 4.0.0.
27822 lines
1.7 MiB
27822 lines
1.7 MiB
/*
|
|
|
|
BLIS
|
|
An object-based framework for developing high-performance BLAS-like
|
|
libraries.
|
|
|
|
Copyright (C) 2018-2019, Advanced Micro Devices, Inc.
|
|
|
|
Redistribution and use in source and binary forms, with or without
|
|
modification, are permitted provided that the following conditions are
|
|
met:
|
|
- Redistributions of source code must retain the above copyright
|
|
notice, this list of conditions and the following disclaimer.
|
|
- Redistributions in binary form must reproduce the above copyright
|
|
notice, this list of conditions and the following disclaimer in the
|
|
documentation and/or other materials provided with the distribution.
|
|
- Neither the name of The University of Texas at Austin nor the names
|
|
of its contributors may be used to endorse or promote products
|
|
derived from this software without specific prior written permission.
|
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
|
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
|
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
|
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
|
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
|
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
|
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
|
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
|
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
|
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
|
|
*/
|
|
|
|
#include "blis.h"
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
|
|
#include "immintrin.h"
|
|
#define GEMM_BLK_V1 8 //Block size to perform gemm and apply trsm
|
|
#define GEMM_ACCUM_A 1 //Peform B1=B1-(B0*A0) operation instead of B1'=(B0*A0) and then B1=B1-B1'
|
|
#define OPT_CACHE_BLOCKING_L1 1 //Perform trsm block-wise in blocks of GEMM_BLK_V1 instead of all columns of B together.
|
|
#define REARRANGE_SHFL 0 //Rearrange operations using blend or shuffle
|
|
#define BLI_AlXB_M_SP 16
|
|
#define BLI_XAltB_N_SP 128
|
|
#define BLI_AutXB_M_SP 64
|
|
#define BLI_AutXB_N_SP 128
|
|
|
|
// XA = B; A is lower-triangular; No transpose; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_XAlB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is lower triangular; No transpose; double precision; unit-diagonal
|
|
static err_t bli_dtrsm_small_XAlB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is lower-triangular; A is transposed; double precision; non-unit-diagonal
|
|
static err_t bli_dtrsm_small_XAltB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is lower-triangular; A is transposed; double precision; unit-diagonal
|
|
static err_t bli_dtrsm_small_XAltB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
// XA = B; A is upper triangular; No transpose; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_XAuB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is upper triangular; No transpose; double precision; unit-diagonal
|
|
static err_t bli_dtrsm_small_XAuB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is upper-triangular; A is transposed; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_XAutB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA = B; A is upper-triangular; A is transposed; double precision; unit diagonal
|
|
static err_t bli_dtrsm_small_XAutB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//AX = B; A is lower triangular; No transpose; double precision; non-unit diagonal
|
|
static err_t bli_dtrsm_small_AlXB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//AX = B; A is lower triangular; No transpose; double precision; unit diagonal
|
|
static err_t bli_dtrsm_small_AlXB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
|
|
|
|
static void (*fp_blis_strsm_microkernel)( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
static void blis_strsm_microkernel( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
static void blis_strsm_microkernel_alpha( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal
|
|
);
|
|
static void blis_strsm_microkernel_unitDiag( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
static void blis_strsm_microkernel_alpha_unitDiag( float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal
|
|
);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alphaVal);
|
|
|
|
|
|
static void blis_dtrsm_microkernel( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
|
|
static void blis_dtrsm_microkernel_alpha( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal
|
|
);
|
|
|
|
static void blis_dtrsm_microkernel_unitDiag( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b
|
|
);
|
|
|
|
static void blis_dtrsm_microkernel_alpha_unitDiag( double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal
|
|
);
|
|
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal);
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices_unitDiag(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void dtrsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(double *ptr_l,
|
|
double *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
double alphaVal);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alpha);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b);
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l,
|
|
float *ptr_b,
|
|
int numRows_lb,
|
|
int numCols_b,
|
|
int rs_l,
|
|
int rs_b,
|
|
int cs_l,
|
|
int cs_b,
|
|
float alpha);
|
|
|
|
//AX = B; A is lower triangular; No transpose; single precision
|
|
static err_t bli_strsm_small_AlXB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
//A.'X = B; A is upper triangular; A has to be transposed; single precision
|
|
static err_t bli_strsm_small_AutXB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//XA.' = B; A is lower triangular; A has to be transposed; single precision
|
|
static err_t bli_strsm_small_XAltB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
//A.'X = B; A is upper triangular; A has to be transposed; double precision
|
|
static err_t bli_dtrsm_small_AutXB
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
);
|
|
|
|
/*
|
|
* The bli_trsm_small implements unpacked version of TRSM
|
|
* Currently only column-major is supported, A & B are column-major
|
|
* Input: A: MxM (triangular matrix)
|
|
* B: MxN matrix
|
|
* Output: X: MxN matrix such that AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B
|
|
* Here the output X is stored in B
|
|
* The custom-kernel will be called only when M*(M+N)* sizeof(Matrix Elements) < L3 cache
|
|
*/
|
|
err_t bli_trsm_small
|
|
(
|
|
side_t side,
|
|
obj_t* alpha,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
#ifdef BLIS_ENABLE_MULTITHREADING
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#endif
|
|
|
|
dim_t m = bli_obj_length(b);
|
|
dim_t n = bli_obj_width(b);
|
|
|
|
if(!(m && n))
|
|
return BLIS_SUCCESS;
|
|
|
|
|
|
// If alpha is zero, B matrix will become zero after scaling & hence solution is also zero matrix
|
|
if (bli_obj_equals(alpha, &BLIS_ZERO))
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha
|
|
}
|
|
// We have to call matrix scaling if alpha != 1.0
|
|
|
|
// if row major format return. Check this again.
|
|
if ((bli_obj_row_stride(a) != 1) ||
|
|
(bli_obj_row_stride(b) != 1))
|
|
{
|
|
return BLIS_INVALID_ROW_STRIDE;
|
|
}
|
|
|
|
num_t dt = ((*b).info & (0x7 << 0));
|
|
|
|
// only float and double datatypes are supported as of now.
|
|
if (dt != BLIS_DOUBLE && dt != BLIS_FLOAT)
|
|
{
|
|
return BLIS_EXPECTED_REAL_DATATYPE;
|
|
}
|
|
|
|
// A is expected to be triangular in trsm
|
|
if (!bli_obj_is_upper_or_lower (a))
|
|
{
|
|
return BLIS_EXPECTED_TRIANGULAR_OBJECT;
|
|
}
|
|
|
|
// can use other control structs - even can use array of function pointers,
|
|
// indexed by a number with bits formed by f('side', 'uplo', 'transa', dt).
|
|
// In the below implementation, based on the number of finally implemented
|
|
// cases, can move the checks with more cases higher up.
|
|
|
|
if(side == BLIS_LEFT)
|
|
{
|
|
if(bli_obj_has_trans(a))
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_dtrsm_small_AutXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
//return bli_dtrsm_small_AltXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
return bli_strsm_small_AutXB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
else
|
|
{
|
|
//return bli_strsm_small_AltXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_dtrsm_small_AuXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_AlXB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_AlXB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_strsm_small_AuXB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
return bli_strsm_small_AlXB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_trans(a))
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAutB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAutB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAltB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAltB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_strsm_small_XAutB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
return bli_strsm_small_XAltB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(dt == BLIS_DOUBLE)
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAuB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAuB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_has_unit_diag(a))
|
|
return bli_dtrsm_small_XAlB_unitDiag(side, alpha, a, b, cntx, cntl);
|
|
else
|
|
return bli_dtrsm_small_XAlB(side, alpha, a, b, cntx, cntl);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if(bli_obj_is_upper(a))
|
|
{
|
|
//return bli_strsm_small_XAuB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
else
|
|
{
|
|
//return bli_strsm_small_XAlB(side, alpha, a, b, cntx, cntl);
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
}
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
};
|
|
|
|
/* TRSM scalar code for the case AX = alpha * B
|
|
* A is lower-triangular, non-unit-diagonal, no transpose
|
|
* Dimensions: A: mxm X: mxn B:mxn
|
|
*/
|
|
|
|
static err_t dtrsm_small_AlXB (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for (k = 0; k < M; k++)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
B[k + j*ldb] *= lkk_inv;
|
|
for (i = k+1; i < M; i++)
|
|
{
|
|
B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
|
|
}
|
|
}
|
|
}// k -loop
|
|
return BLIS_SUCCESS;
|
|
}// end of function
|
|
|
|
/* TRSM scalar code for the case AX = alpha * B
|
|
* A is lower-triangular, unit-diagonal, no transpose
|
|
* Dimensions: A: mxm X: mxn B:mxn
|
|
*/
|
|
|
|
static err_t dtrsm_small_AlXB_unitDiag (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for (k = 0; k < M; k++)
|
|
{
|
|
for (j = 0; j < N; j++)
|
|
{
|
|
for (i = k+1; i < M; i++)
|
|
{
|
|
B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}// end of function
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is upper-triangular, non-unit-diagonal no transpose
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAuB (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
B[i+k*ldb] *= lkk_inv;
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda];
|
|
}
|
|
}
|
|
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is lower-triangular, non-unit triangular, no transpose
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
|
|
static err_t dtrsm_small_XAlB (
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
for(j = 0; j < N; j++)
|
|
for(i = 0; i < M; i++)
|
|
B[i+j*ldb] *= alpha;
|
|
|
|
for(k = N;k--;)
|
|
{
|
|
double lkk_inv = 1.0/A[(k)+(k)*lda];
|
|
for(i = M;i--;)
|
|
{
|
|
B[(i)+(k)*ldb] *= lkk_inv;
|
|
for(j = k;j--;)
|
|
{
|
|
B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A[(k)+(j)*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is lower-triangular, unit-diagonal, no transpose
|
|
*Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAlB_unitDiag(
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(j = 0 ; j < N; j++)
|
|
for(i = 0; i < M; i++)
|
|
B[i+j*ldb] *= alpha;
|
|
double A_k_j;
|
|
for(k = N; k--;)
|
|
{
|
|
for(j = k; j--;)
|
|
{
|
|
A_k_j = A[(k)+(j)*lda];
|
|
for(i = M; i--;)
|
|
{
|
|
B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A_k_j;
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
*A is upper-triangular, non-unit-diagonal, A is transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAutB (
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(j = 0; j < N; j++)
|
|
for(i = 0; i < M; i++)
|
|
B[i+j*ldb] *=alpha;
|
|
|
|
for(k = N; k--;)
|
|
{
|
|
double lkk_inv = 1.0/A[(k)+(k)*lda];
|
|
for(i = M; i--;)
|
|
{
|
|
B[(i)+(k)*ldb] *= lkk_inv;
|
|
for(j = k; j--;)
|
|
{
|
|
B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A[(j)+(k)*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is upper-triangular, unit-diagonal, A has to be transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAutB_unitDiag(
|
|
double *A,
|
|
double *B,
|
|
double alpha,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
double A_k_j;
|
|
|
|
for(j = 0; j< N; j++)
|
|
for(i = 0; i< M; i++)
|
|
B[i+j*ldb] *= alpha;
|
|
|
|
for(k = N; k--;)
|
|
{
|
|
for(j = k; j--;)
|
|
{
|
|
A_k_j = A[(j)+(k)*lda];
|
|
for(i = M; i--;)
|
|
{
|
|
B[(i)+(j)*ldb] -= B[(i)+(k)*ldb] * A_k_j;
|
|
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is lower-triangular, non-unit-diagonal, A has to be transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAltB (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
double lkk_inv = 1.0/A[k+k*lda];
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
B[i+k*ldb] *= lkk_inv;
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for XA = alpha * B
|
|
* A is lower-triangular, unit-diagonal, A has to be transposed
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAltB_unitDiag(
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[j+k*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM scalar code for the case XA = alpha * B
|
|
* A is upper-triangular, unit-diagonal, no transpose
|
|
* Dimensions: X:mxn A:nxn B:mxn
|
|
*/
|
|
static err_t dtrsm_small_XAuB_unitDiag (
|
|
double *A,
|
|
double *B,
|
|
dim_t M,
|
|
dim_t N,
|
|
dim_t lda,
|
|
dim_t ldb
|
|
)
|
|
{
|
|
|
|
dim_t i, j, k;
|
|
|
|
for(k = 0; k < N; k++)
|
|
{
|
|
for(i = 0; i < M; i++)
|
|
{
|
|
for(j = k+1; j < N; j++)
|
|
{
|
|
B[i+j*ldb] -= B[i+k*ldb] * A[k+j*lda];
|
|
}
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM for the case AX = alpha * B, Double precision
|
|
* A is lower-triangular, no-transpose, non-unit diagonal
|
|
* dimensions A: mxm X: mxn B: mxn
|
|
|
|
b01--->
|
|
* *****************
|
|
** * * * * *
|
|
* * * * * * *
|
|
* * *b01* * * *
|
|
* * * * * * *
|
|
a10 ****** b11 *****************
|
|
| * * * | * * * * *
|
|
| * * * | * * * * *
|
|
| *a10*a11* | *b11* * * *
|
|
v * * * v * * * * *
|
|
*********** *****************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
**************** *****************
|
|
a11--->
|
|
*/
|
|
static err_t bli_dtrsm_small_AlXB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
|
|
dim_t D_MR = 4; //size of block along 'M' dimpension
|
|
dim_t D_NR = 8; //size of block along 'N' dimension
|
|
|
|
dim_t m = bli_obj_length(b); // number of rows of matrix B
|
|
dim_t n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
|
|
|| (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
|
|
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n<D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t m_remainder = m & 3; //number of remainder rows
|
|
dim_t n_remainder = n & 7; //number of remainder columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); // column stride of A
|
|
dim_t cs_b = bli_obj_col_stride(b); // column stride of B
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //number of times GEMM to be performed
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
|
|
double *ptr_b01_dup;
|
|
|
|
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
|
|
double* f_temp;
|
|
|
|
double ones = 1.0;
|
|
|
|
//scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
|
|
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code begins///
|
|
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //mobe to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[3][0] B01[3][0]*A10[2][3] B01[3][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[3][0] B01[3][1]*A10[2][3] B01[3][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[3][0] B01[3][2]*A10[2][3] B01[3][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[3][0] B01[3][3]*A10[2][3] B01[3][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[3][0] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[3][0] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[3][0] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[3][0] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM
|
|
}
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7]
|
|
|
|
///implement TRSM///
|
|
|
|
///transpose of B11//
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//broadcast diagonal elements of A11
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+cs_a*3 + 3)); //A11[3][3]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm3, ymm4); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
|
|
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
            ymm0 = _mm256_div_pd(ymm0, ymm5);              //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a00
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
|
|
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a22
|
|
            ymm1 = _mm256_permute_pd(ymm0, 0x00);           //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
|
|
|
|
//perform mul operation
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /= A11[2][2]
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a33
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11);//1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6]
|
|
|
|
//perform mul operation
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm1); //B11[0-3][3] /= A11[3][3]
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm1); //B11[0-3][7] /= A11[3][3]
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3]
|
|
}
|
|
|
|
if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4)
|
|
|
|
int iter;
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_t[iter] = (b11 + cs_b * 7)[iter];
|
|
f_temp = f_t;
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * 7);
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code Begins///
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] )
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
                ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15);    //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //move to next row of B01
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
                ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12);    //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])


                ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13);    //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])


                ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14);    //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])


                ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15);    //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
                ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8);      //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])


                ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9);      //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])


                ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10);    //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])


                ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11);    //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])




                ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12);    //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])


                ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13);    //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])


                ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14);    //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])


                ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15);    //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7]
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//broadcast diagonal elements of A11
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+cs_a*2 + 2)); //A11[2][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm3, ymm0); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
|
|
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a00
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
|
|
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0]
|
|
|
|
                ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13);   //B11[5][0-3] -= B11[0-3][4]*A11[1][0]


                ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14);   //B11[6][0-3] -= B11[0-3][4]*A11[2][0]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//extract a22
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
|
|
//perform mul operation
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm1); //B11[0-3][2] /=A11[2][2]
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm1); //B11[0-3][6] /= A11[2][2]
|
|
|
|
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
|
|
ymm15 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08);
|
|
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//broadcast diagonal elements of A11
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a +1)); //A11[1][1]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm1, ymm2); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
|
|
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm0, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm0, ymm5); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a00
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
|
|
//(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm1); //B11[0-3][0] /= A11[0][0]
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm1); //B11[0-3][4] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm1 = _mm256_permute_pd(ymm0, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[3][3] 1/A11[3][3]
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
|
|
|
|
                ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13);   //B11[5][0-3] -= B11[0-3][4]*A11[1][0]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm1); //B11[0-3][1] /= A11[1][1]
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm1); //B11[0-3][5] /= A11[1][1]
|
|
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30);
|
|
ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30);
|
|
ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30);
|
|
ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30);
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30);
|
|
|
|
}
|
|
else if(1 == m_remainder)
{
    ///implement TRSM///
    //Only row 0 of the 4x8 B11 tile is a real row of B. Transpose the
    //eight loaded columns into row form, divide row 0 by A11[0][0]
    //(the only TRSM step needed for a single row), transpose back, and
    //finally blend with freshly loaded B11 columns so lanes 1-3 (rows
    //that lie outside the matrix) are written back unchanged.

    ///unpacklow///
    ymm9 = _mm256_unpacklo_pd(ymm0, ymm1);              //B11[0][0] B11[0][1] B11[2][0] B11[2][1]

    ymm11 = _mm256_unpacklo_pd(ymm2, ymm3);             //B11[0][2] B11[0][3] B11[2][2] B11[2][3]

    ymm13 = _mm256_unpacklo_pd(ymm4, ymm5);             //B11[0][4] B11[0][5] B11[2][4] B11[2][5]

    ymm15 = _mm256_unpacklo_pd(ymm6, ymm7);             //B11[0][6] B11[0][7] B11[2][6] B11[2][7]

    //rearrange low elements (finish transposing columns into rows)
    ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20);     //B11[0][0] B11[0][1] B11[0][2] B11[0][3]

    ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31);    //B11[2][0] B11[2][1] B11[2][2] B11[2][3]

    ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20);   //B11[0][4] B11[0][5] B11[0][6] B11[0][7]

    ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31);   //B11[2][4] B11[2][5] B11[2][6] B11[2][7]

    ////unpackhigh////
    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);              //B11[1][0] B11[1][1] B11[3][0] B11[3][1]

    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);              //B11[1][2] B11[1][3] B11[3][2] B11[3][3]

    ymm4 = _mm256_unpackhi_pd(ymm4, ymm5);              //B11[1][4] B11[1][5] B11[3][4] B11[3][5]

    ymm5 = _mm256_unpackhi_pd(ymm6, ymm7);              //B11[1][6] B11[1][7] B11[3][6] B11[3][7]

    //rearrange high elements
    ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);      //B11[1][0] B11[1][1] B11[1][2] B11[1][3]

    ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);     //B11[3][0] B11[3][1] B11[3][2] B11[3][3]

    ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20);     //B11[1][4] B11[1][5] B11[1][6] B11[1][7]

    ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31);     //B11[3][4] B11[3][5] B11[3][6] B11[3][7]

    ymm0 = _mm256_broadcast_sd((double const *)&ones);

    //broadcast the single valid diagonal element of A11
    ymm1 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]

    ymm0 = _mm256_div_pd(ymm0, ymm1);                   //1/A11[0][0] in every lane

    //extract a00 (no-op here: all lanes already hold 1/A11[0][0])
    ymm1 = _mm256_permute_pd(ymm0, 0x00);               //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]

    ymm1 = _mm256_permute2f128_pd(ymm1, ymm1, 0x00);    //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]

    //(Row 0): perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
    ymm8 = _mm256_mul_pd(ymm8, ymm1);                   //B11[0][0-3] /= A11[0][0]

    ymm12 = _mm256_mul_pd(ymm12, ymm1);                 //B11[0][4-7] /= A11[0][0]

    //rows 1-3 are dummies; use a vector of ones as filler for the
    //back-transpose below (these lanes are discarded by the blends)
    ymm9 = _mm256_broadcast_sd((double const *)(&ones));

    //unpacklow// (transpose solved rows back to column form)
    ymm1 = _mm256_unpacklo_pd(ymm8, ymm9);              //B11[0][0] 1.0 B11[0][2] 1.0

    ymm5 = _mm256_unpacklo_pd(ymm12, ymm9);             //B11[0][4] 1.0 B11[0][6] 1.0

    //rearrange low elements (only lane 0 of each result is valid)
    ymm0 = _mm256_permute2f128_pd(ymm1, ymm9, 0x20);    //B11[0][0] in lane 0 -> column 0

    ymm2 = _mm256_permute2f128_pd(ymm1, ymm9, 0x31);    //B11[0][2] in lane 0 -> column 2

    ymm4 = _mm256_permute2f128_pd(ymm5, ymm9, 0x20);    //B11[0][4] in lane 0 -> column 4

    ymm6 = _mm256_permute2f128_pd(ymm5, ymm9, 0x31);    //B11[0][6] in lane 0 -> column 6

    ///unpack high///
    ymm8 = _mm256_unpackhi_pd(ymm8, ymm9);              //B11[0][1] 1.0 B11[0][3] 1.0

    ymm12 = _mm256_unpackhi_pd(ymm12, ymm9);            //B11[0][5] 1.0 B11[0][7] 1.0

    //rearrange high elements (only lane 0 of each result is valid)
    ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);    //B11[0][1] in lane 0 -> column 1

    ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);    //B11[0][3] in lane 0 -> column 3

    ymm5 = _mm256_permute2f128_pd(ymm12, ymm9, 0x20);   //B11[0][5] in lane 0 -> column 5

    ymm7 = _mm256_permute2f128_pd(ymm12, ymm9, 0x31);   //B11[0][7] in lane 0 -> column 7

    //reload the unmodified B11 columns (column 7 via f_temp)
    ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));   //load B11[0-3][0]

    ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1));   //load B11[0-3][1]

    ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));  //load B11[0-3][2]

    ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3));  //load B11[0-3][3]

    ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4));  //load B11[0-3][4]

    ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5));  //load B11[0-3][5]

    ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6));  //load B11[0-3][6]

    ymm15 = _mm256_loadu_pd((double const *)(f_temp));          //load B11[0-3][7]

    //determine correct values to store: blend mask 0x0E keeps the solved
    //element in lane 0 and the original B values in lanes 1-3
    ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E);

    ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E);

    ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E);

    ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E);

    ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E);

    ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E);

    ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E);

    ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E);

}
|
|
|
|
//Write the solved 4x8 tile of B11 back to memory, one column per store.
//Column 7 goes through f_temp, which was set earlier to either
//(b11 + cs_b*7) directly or to the on-stack staging buffer f_t when this
//is the final column block (to avoid storing past the end of B) —
//NOTE(review): f_temp assignment is above this span; verified by the
//analogous setup in the n&4 path below.
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0);     //store(B11[0-3][0])

_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);     //store(B11[0-3][1])

_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2);     //store(B11[0-3][2])

_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3);     //store(B11[0-3][3])

_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4);     //store(B11[0-3][4])

_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5);     //store(B11[0-3][5])

_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6);     //store(B11[0-3][6])

_mm256_storeu_pd((double *)(f_temp), ymm7);             //store(B11[0-3][7])

//If this was the last column block, f_temp pointed at the stack buffer
//f_t rather than at B; copy only the valid m_remainder entries of
//column 7 back into B so no memory beyond B is written.
if((j+D_NR) == n)
{
    for(iter = 0; iter < m_remainder; iter++)
        (b11 + cs_b * 7)[iter] = f_t[iter];
}
|
|
}
|
|
}
|
|
|
|
if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4)
|
|
{
|
|
//Main M-loop for the 4-column remainder path: for each full 4-row panel
//of B, first apply the GEMM update B11 := alpha*B11 - A10*B01 using all
//previously solved row panels, then solve the 4x4 lower-triangular
//system A11 * X = B11 in registers, and store the result transposed
//back into column-major B.
for(i = 0;i+D_MR-1 < m; i += D_MR)      //loop along 'M' direction
{
    a10 = L +i;                         //pointer to block of A to be used for GEMM
    a11 = L + i + (i*cs_a);             //pointer to block of A to be used for TRSM
    b01 = B + j*cs_b;                   //pointer to block of B to be used for GEMM
    b11 = B + i + j* cs_b;              //pointer to block of B to be used for TRSM

    k_iter = i / D_MR;                  //number of times GEMM to be performed(in block of 4)

    ///GEMM for previously calculated values ///

    //load 4x4 block from b11
    ymm0 = _mm256_loadu_pd((double const *)(b11));          //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
    ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b));   //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
    ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
    ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]

    //zero the GEMM accumulators, one per column of B11
    ymm4 = _mm256_setzero_pd();
    ymm5 = _mm256_setzero_pd();
    ymm6 = _mm256_setzero_pd();
    ymm7 = _mm256_setzero_pd();

    for(k = 0; k < k_iter; k++)         //loop for number of GEMM operations
    {
        ptr_b01_dup = b01;

        ymm8 = _mm256_loadu_pd((double const *)(a10));              //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
        ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a));       //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
        ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2));    //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
        ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));  //A10[0][3] A10[1][3] A10[2][3] A10[3][3]

        ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[0][0]
        ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[0][1]
        ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[0][2]
        ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[0][3]

        b01 += 1;                       //move to next row of B

        ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4);  //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
        ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5);  //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
        ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6);  //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
        ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7);  //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])

        ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[1][0]
        ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[1][1]
        ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[1][2]
        ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[1][3]

        b01 += 1;

        ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4);  //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
        ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5);  //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
        ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6);  //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
        ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7);  //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])

        ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[2][0]
        ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[2][1]
        ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[2][2]
        ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[2][3]

        b01 += 1;

        ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
        ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
        ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
        ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])

        ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[3][0]
        ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[3][1]
        ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[3][2]
        ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[3][3]

        b01 += 1;

        ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
        ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
        ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
        ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])

        a10 += D_MR * cs_a;             //pointer math to find next block of A for GEMM
        b01 = ptr_b01_dup + D_MR;       //pointer math to find next block of B for GEMM

    }

    ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);     //register to store alpha

    //B11 := alpha*B11 - (A10*B01 accumulated above)
    ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);      //B11[0-3][0] *alpha -= ymm4
    ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);      //B11[0-3][1] *alpha -= ymm5
    ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);      //B11[0-3][2] *alpha -= ymm6
    ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);      //B11[0-3][3] *alpha -= ymm7

    ///implement TRSM///
    //broadcast the strictly-lower entries of the 4x4 A11 panel
    //1st col
    ymm4 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]
    ymm5 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][0]
    ymm6 = _mm256_broadcast_sd((double const *)(a11+2));    //A11[2][0]
    ymm7 = _mm256_broadcast_sd((double const *)(a11+3));    //A11[3][0]

    //2nd col
    a11 += cs_a;
    ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1));  //A11[1][1]
    ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2));  //A11[2][1]
    ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]

    //3rd col
    a11 += cs_a;
    ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
    ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]

    //4th col
    a11 += cs_a;
    ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]

    //compute reciprocals of L(i,i) and broadcast in registers
    ymm4 = _mm256_unpacklo_pd(ymm4, ymm8);      //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
    ymm8 = _mm256_unpacklo_pd(ymm11, ymm13);    //A11[2][2] A11[3][3] A11[2][2] A11[3][3]

    ymm14 = _mm256_broadcast_sd((double const *)&ones);

    ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C);   //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
    ymm14 = _mm256_div_pd(ymm14, ymm4);         //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]

    ////unpacklow//// (transpose B11 columns into rows for row-wise solve)
    ymm8 = _mm256_unpacklo_pd(ymm0, ymm1);      //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
    ymm13 = _mm256_unpacklo_pd(ymm2, ymm3);     //B11[0][2] B11[0][3] B11[2][2] B11[2][3]

    //rearrange low elements
    ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20);     //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
    ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);    //B11[2][0] B11[2][1] B11[2][2] B11[2][3]

    ////unpackhigh////
    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);      //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);      //B11[1][2] B11[1][3] B11[3][2] B11[3][3]

    //rearrange high elements
    ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);      //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
    ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);     //B11[3][0] B11[3][1] B11[3][2] B11[3][3]

    //extract a00
    ymm15 = _mm256_permute_pd(ymm14, 0x00);             //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]

    //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
    ymm4 = _mm256_mul_pd(ymm4, ymm15);          //B11[0][0-3] /= A11[0][0]

    //extract diag a11 from a
    ymm15 = _mm256_permute_pd(ymm14, 0x03);             //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]

    //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
    ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);      //d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3]
    ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);    //d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3]
    ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);    //d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3]

    //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
    ymm8 = _mm256_mul_pd(ymm8, ymm15);          //B11[1][0-3] /= A11[1][1]

    //extract diag a22 from a
    ymm15 = _mm256_permute_pd(ymm14, 0x00);             //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]

    //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
    ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);    //d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3]
    ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);   //d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3]

    //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
    ymm11 = _mm256_mul_pd(ymm11, ymm15);        //B11[2][0-3] /= A11[2][2]

    //extract diag a33 from a
    ymm15 = _mm256_permute_pd(ymm14, 0x0C);             //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]

    //(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
    ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);  //d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3]

    //Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
    ymm13 = _mm256_mul_pd(ymm13, ymm15);        //B11[3][0-3] /= A11[3][3]

    //--> Transpose and store results of columns of B block <--//
    ////unpacklow////
    ymm1 = _mm256_unpacklo_pd(ymm4, ymm8);      //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
    ymm3 = _mm256_unpacklo_pd(ymm11, ymm13);    //B11[2][0] B11[3][0] B11[2][2] B11[3][2]

    //rearrange low elements
    ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20);  //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
    ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31);  //B11[0][2] B11[1][2] B11[2][2] B11[3][2]

    ////unpackhigh////
    ymm14 = _mm256_unpackhi_pd(ymm4, ymm8);     //B11[0][1] B11[1][1] B11[0][3] B11[1][3]

    ymm15 = _mm256_unpackhi_pd(ymm11, ymm13);   //B11[2][1] B11[3][1] B11[2][3] B11[3][3]

    //rearrange high elements
    ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20);    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
    ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31);    //B11[0][3] B11[1][3] B11[2][3] B11[3][3]

    _mm256_storeu_pd((double *)b11, ymm0);              //store(B11[0-3][0])
    _mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1);   //store(B11[0-3][1])
    _mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2);   //store(B11[0-3][2])
    _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3);   //store(B11[0-3][3])

}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
//Remainder-row path (m_remainder rows left after the main M-loop):
//set up pointers, stage column 3 through f_temp/f_t when this is the
//last column block, apply the GEMM update, and scale by alpha. The
//per-remainder TRSM branches follow this span.
a10 = L +i;                     //pointer to block of A to be used for GEMM
a11 = L + i + (i*cs_a);         //pointer to block of A to be used for TRSM
b01 = B + j*cs_b;               //pointer to block of B to be used for GEMM
b11 = B + i + j* cs_b;          //pointer to block of B to be used for TRSM

k_iter = i / D_MR;              //number of GEMM operations to be performed(in blocks of 4x4)

dim_t iter;

//When column j+3 is the last column of B, a full 4-element load/store
//on it would touch memory past the end of B. Copy its valid
//m_remainder entries into the stack buffer f_t and work through
//f_temp; otherwise f_temp aliases the column in B directly.
if((j+4) == n)
{
    f_temp = f_t;
    for(iter = 0; iter < m_remainder; iter++)
        f_temp[iter] = (b11 + cs_b * 3)[iter];
}
else
    f_temp = (b11 + cs_b * 3);

///GEMM for previously calculated values ///

//load 4x4 block from b11 (column 3 via f_temp)
ymm0 = _mm256_loadu_pd((double const *)(b11));              //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b));       //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));   //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
ymm3 = _mm256_loadu_pd((double const *)(f_temp));           //B11[0][3] B11[1][3] B11[2][3] B11[3][3]

//zero the GEMM accumulators
ymm4 = _mm256_setzero_pd();
ymm5 = _mm256_setzero_pd();
ymm6 = _mm256_setzero_pd();
ymm7 = _mm256_setzero_pd();

for(k = 0; k < k_iter; k++)     //looop for number of GEMM operations
{
    ptr_b01_dup = b01;

    ymm8 = _mm256_loadu_pd((double const *)(a10));              //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
    ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a));       //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
    ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));  //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
    ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));  //A10[0][3] A10[1][3] A10[2][3] A10[3][3]

    ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[0][0]
    ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[0][1]
    ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[0][2]
    ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[0][3]

    b01 += 1;

    ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4);  //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
    ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5);  //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
    ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6);  //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
    ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7);  //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])

    ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[1][0]
    ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[1][1]
    ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[1][2]
    ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[1][3]

    b01 += 1;

    ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4);  //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
    ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5);  //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
    ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6);  //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
    ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7);  //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])

    ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[2][0]
    ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[2][1]
    ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[2][2]
    ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[2][3]

    b01 += 1;

    ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
    ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
    ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
    ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])

    ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0));  //B01[3][0]
    ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1));  //B01[3][1]
    ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2));  //B01[3][2]
    ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3));  //B01[3][3]

    b01 += 1;

    ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
    ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
    ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
    ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])

    a10 += D_MR * cs_a;             //pointer math to find next block of A for GEMM
    b01 = ptr_b01_dup + D_MR;       //pointer math to find next block of B for GEMM

}

ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);     //register to store alpha

//B11 := alpha*B11 - (A10*B01 accumulated above)
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);      //B11[0-3][0] *alpha -= ymm4
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);      //B11[0-3][1] *alpha -= ymm5
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);      //B11[0-3][2] *alpha -= ymm6
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);      //B11[0-3][3] *alpha -= ymm7
|
|
|
|
|
|
if(3 == m_remainder)
{
    ///implement TRSM///
    //Three valid rows: solve a 3x3 lower-triangular system in row form;
    //row 3 is a dummy (its diagonal slot is filled with 1.0 so the
    //division is harmless) and is masked out by the final blends.

    //1st col
    ymm4 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]
    ymm5 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][0]
    ymm6 = _mm256_broadcast_sd((double const *)(a11+2));    //A11[2][0]

    //2nd col
    a11 += cs_a;
    ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1));  //A11[1][1]
    ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2));  //A11[2][1]

    //3rd col
    a11 += cs_a;
    ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]

    //4th col: no real A11[3][3]; use 1.0 as the dummy diagonal
    a11 += cs_a;
    ymm13 = _mm256_broadcast_sd((double const *)(&ones));   //1.0 in place of A11[3][3]

    //compute reciprocals of L(i,i) and broadcast in registers
    ymm4 = _mm256_unpacklo_pd(ymm4, ymm8);      //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
    ymm8 = _mm256_unpacklo_pd(ymm11, ymm13);    //A11[2][2] 1.0 A11[2][2] 1.0

    ymm14 = _mm256_broadcast_sd((double const *)&ones);

    ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C);   //A11[0][0] A11[1][1] A11[2][2] 1.0
    ymm14 = _mm256_div_pd(ymm14, ymm4);         //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1.0

    ////unpacklow//// (transpose B11 columns into rows)
    ymm8 = _mm256_unpacklo_pd(ymm0, ymm1);      //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
    ymm13 = _mm256_unpacklo_pd(ymm2, ymm3);     //B11[0][2] B11[0][3] B11[2][2] B11[2][3]

    //rearrange low elements
    ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20);     //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
    ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);    //B11[2][0] B11[2][1] B11[2][2] B11[2][3]

    ////unpackhigh////
    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);      //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);      //B11[1][2] B11[1][3] B11[3][2] B11[3][3]

    //rearrange high elements
    ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);      //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
    ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);     //B11[3][0] B11[3][1] B11[3][2] B11[3][3]

    //extract a00
    ymm15 = _mm256_permute_pd(ymm14, 0x00);             //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]

    //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
    ymm4 = _mm256_mul_pd(ymm4, ymm15);          //B11[0][0-3] /= A11[0][0]

    //extract diag a11 from a
    ymm15 = _mm256_permute_pd(ymm14, 0x03);             //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]

    //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
    ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);      //d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
    ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);    //d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3]

    //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
    ymm8 = _mm256_mul_pd(ymm8, ymm15);          //B11[1][0-3] /= A11[1][1]

    //extract diag a22 from a
    ymm15 = _mm256_permute_pd(ymm14, 0x00);             //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]

    //(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
    ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);    //d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3]

    //Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
    ymm11 = _mm256_mul_pd(ymm11, ymm15);        //B11[2][0-3] /= A11[2][2]

    //dummy row 3 for the back-transpose (masked off by the blends below)
    ymm13 = _mm256_broadcast_sd((double const *)(&ones));

    //--> Transpose and store results of columns of B block <--//
    ////unpacklow////
    ymm1 = _mm256_unpacklo_pd(ymm4, ymm8);      //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
    ymm3 = _mm256_unpacklo_pd(ymm11, ymm13);    //B11[2][0] B11[3][0] B11[2][2] B11[3][2]

    //rearrange low elements
    ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20);  //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
    ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31);  //B11[0][2] B11[1][2] B11[2][2] B11[3][2]

    ////unpackhigh////
    ymm14 = _mm256_unpackhi_pd(ymm4, ymm8);     //B11[0][1] B11[1][1] B11[0][3] B11[1][3]

    ymm15 = _mm256_unpackhi_pd(ymm11, ymm13);   //B11[2][1] B11[3][1] B11[2][3] B11[3][3]

    //rearrange high elements
    ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20);    //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
    ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31);    //B11[0][3] B11[1][3] B11[2][3] B11[3][3]

    //load 4x4 block from b11 (column 3 via f_temp)
    ymm4 = _mm256_loadu_pd((double const *)(b11));              //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
    ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));       //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
    ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));   //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
    ymm7 = _mm256_loadu_pd((double const *)(f_temp));           //B11[0][3] B11[1][3] B11[2][3] B11[3][3]

    //determine correct values to store: blend mask 0x08 keeps the solved
    //values in lanes 0-2 and the original B value in lane 3 (the dummy row)
    ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);

    ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);

    ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);

    ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);

}
|
|
else if( 2 == m_remainder )
{
    ///implement TRSM///
    //Two valid rows: solve a 2x2 lower-triangular system in row form;
    //rows 2-3 are dummies (ones) and are masked out by the final
    //permute2f128 selections.

    //1st col
    ymm4 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]
    ymm5 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][0]

    //2nd col
    a11 += cs_a;
    ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1));  //A11[1][1]

    //compute reciprocals of L(i,i) and broadcast in registers
    ymm4 = _mm256_unpacklo_pd(ymm4, ymm8);      //A11[0][0] A11[1][1] A11[0][0] A11[1][1]

    ymm14 = _mm256_broadcast_sd((double const *)&ones);

    ymm4 = _mm256_blend_pd(ymm4, ymm14, 0x0C);  //A11[0][0] A11[1][1] 1.0 1.0
    ymm14 = _mm256_div_pd(ymm14, ymm4);         //1/A11[0][0] 1/A11[1][1] 1.0 1.0

    ////unpacklow//// (transpose B11 columns into rows)
    ymm8 = _mm256_unpacklo_pd(ymm0, ymm1);      //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
    ymm13 = _mm256_unpacklo_pd(ymm2, ymm3);     //B11[0][2] B11[0][3] B11[2][2] B11[2][3]

    //rearrange low elements
    ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20);     //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
    ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);    //B11[2][0] B11[2][1] B11[2][2] B11[2][3]

    ////unpackhigh////
    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);      //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);      //B11[1][2] B11[1][3] B11[3][2] B11[3][3]

    //rearrange high elements
    ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);      //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
    ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);     //B11[3][0] B11[3][1] B11[3][2] B11[3][3]

    //extract a00
    ymm15 = _mm256_permute_pd(ymm14, 0x00);             //1/A11[0][0] 1/A11[0][0] 1.0 1.0
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]

    //(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
    ymm4 = _mm256_mul_pd(ymm4, ymm15);          //B11[0][0-3] /= A11[0][0]

    //extract diag a11 from a
    ymm15 = _mm256_permute_pd(ymm14, 0x03);             //1/A11[1][1] 1/A11[1][1] 1.0 1.0
    ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]

    //(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
    ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);  //d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]

    //Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
    ymm8 = _mm256_mul_pd(ymm8, ymm15);          //B11[1][0-3] /= A11[1][1]

    //dummy rows 2-3 for the back-transpose (masked off below)
    ymm11 = _mm256_broadcast_sd((double const *)(&ones));

    ymm13 = _mm256_broadcast_sd((double const *)(&ones));

    //--> Transpose and store results of columns of B block <--//
    ////unpacklow////
    ymm1 = _mm256_unpacklo_pd(ymm4, ymm8);      //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
    ymm3 = _mm256_unpacklo_pd(ymm11, ymm13);    //1.0 1.0 1.0 1.0 (dummy rows 2-3)

    //rearrange low elements
    ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20);  //B11[0][0] B11[1][0] then dummies -> column 0
    ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31);  //B11[0][2] B11[1][2] then dummies -> column 2

    ////unpackhigh////
    ymm14 = _mm256_unpackhi_pd(ymm4, ymm8);     //B11[0][1] B11[1][1] B11[0][3] B11[1][3]

    ymm15 = _mm256_unpackhi_pd(ymm11, ymm13);   //1.0 1.0 1.0 1.0 (dummy rows 2-3)

    //rearrange high elements
    ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20);    //B11[0][1] B11[1][1] then dummies -> column 1
    ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31);    //B11[0][3] B11[1][3] then dummies -> column 3

    //load 4x4 block from b11 (column 3 via f_temp)
    ymm4 = _mm256_loadu_pd((double const *)(b11));              //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
    ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));       //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
    ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));   //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
    ymm7 = _mm256_loadu_pd((double const *)(f_temp));           //B11[0][3] B11[1][3] B11[2][3] B11[3][3]

    //determine correct values to store: selector 0x30 keeps the solved
    //low 128 bits (rows 0-1) and the original high 128 bits (rows 2-3)
    ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30);

    ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30);

    ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30);

    ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30);

}
|
|
else if(1 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//extract a00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00);//1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
|
|
|
|
ymm8 = _mm256_broadcast_sd((double const *)(&ones));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
|
|
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//load 4x4 block from b11
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm7 = _mm256_loadu_pd((double const *)(f_temp));    //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3])
|
|
|
|
if((j+4) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * 3)[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
|
|
n_remainder -= 4;
|
|
j += 4;
|
|
|
|
}
|
|
|
|
if(n_remainder) //implementation for remaining columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
}
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11 + 1)); //A11[1][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][3]
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[0][0] A11[1][1] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[2][2] A11[3][3] A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); //1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//extract a00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0]
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
ymm4 = _mm256_mul_pd(ymm4, ymm15); //B11[0][0-3] /= A11[0][0]
|
|
|
|
//extract diag a11 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3]
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm15); //B11[1][0-3] /= A11[1][1]
|
|
|
|
|
|
//extract diag a22 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm15); //B11[2][0-3] /= A11[2][2]
|
|
|
|
//extract diag a33 from a
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3]
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm15); //B11[3][0-3] /= A11[3][3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operations to be performed
|
|
|
|
dim_t iter;
|
|
if((j+n_remainder) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * (n_remainder -1));
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
|
|
//load 4x4 block from b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
|
|
ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2])
|
|
}
|
|
if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1])
|
|
}
|
|
if(n_remainder == 1)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
|
|
}
|
|
|
|
///scalar code for trsm without alpha///
|
|
dtrsm_small_AlXB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* TRSM for the case AX = alpha * B, Double precision
|
|
* A is lower-triangular, no-transpose, unit diagonal
|
|
* dimensions A: mxm X: mxn B: mxn
|
|
|
|
b01--->
|
|
* *****************
|
|
** * * * * *
|
|
* * * * * * *
|
|
* * *b01* * * *
|
|
* * * * * * *
|
|
a10 ****** b11 *****************
|
|
| * * * | * * * * *
|
|
| * * * | * * * * *
|
|
| *a10*a11* | *b11* * * *
|
|
v * * * v * * * * *
|
|
*********** *****************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
**************** *****************
|
|
a11--->
|
|
*/
|
|
|
|
static err_t bli_dtrsm_small_AlXB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
|
|
dim_t D_MR = 4; //size of block along 'M' dimension
|
|
dim_t D_NR = 8; //size of block along 'N' dimension
|
|
|
|
dim_t m = bli_obj_length(b); // number of rows of matrix B
|
|
dim_t n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME)
|
|
|| (m> D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
|
|
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_M && n<D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_ROME_COLUMN_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_ALXB_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t m_remainder = m & (3); //number of remainder rows
|
|
dim_t n_remainder = n & (7); //number of remainder columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); // column stride of A
|
|
dim_t cs_b = bli_obj_col_stride(b); // column stride of B
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //number of times GEMM to be performed
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a10, *a11, *b01, *b11; //pointers that point to blocks for GEMM and TRSM
|
|
double *ptr_b01_dup;
|
|
|
|
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
|
|
double* f_temp;
|
|
|
|
double ones = 1.0;
|
|
|
|
//scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
|
|
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' dimension
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' dimension
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in blocks of 4x4)
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code begins///
|
|
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10));//A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a));//A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2));//A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][4]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][5]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][6]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][7]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3));//A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to calculate next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to calculate next block of B for GEMM
|
|
}
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b *7)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] * alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] * alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] * alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] * alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] * alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] * alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] * alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] * alpha -= B01[0-3][7]
|
|
|
|
///implement TRSM///
|
|
|
|
///transpose of B11//
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[2][4] B11[2][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[2][6] B11[2][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[1][4] B11[1][5] B11[3][4] B11[3][5]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[1][6] B11[1][7] B11[3][6] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= A11[1][0] * B11[0-3][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= A11[2][0] * B11[0-3][0]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); //B11[3][0-3] -= A11[3][0] * B11[0-3][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= A11[1][0] * B11[0-3][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= A11[2][0] * B11[0-3][4]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); //B11[7][0-3] -= A11[3][0] * B11[0-3][4]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); //B11[3][0-3] -= A11[3][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); //B11[7][0-3] -= A11[3][1] * B11[0-3][5]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11 +3)); //A11[3][2]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw1): FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); //B11[3][0-3] -= A11[3][2] * B11[0-3][2]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); //B11[7][0-3] -= A11[3][2] * B11[0-3][6]
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[4][1] B11[5][1] B11[4][3] B11[5][3]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[6][1] B11[7][1] B11[6][3] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store B11[0][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store B11[1][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store B11[2][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store B11[3][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store B11[4][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store B11[5][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store B11[6][0-3]
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); //store B11[7][0-3]
|
|
}
|
|
|
|
if(m_remainder) //implementation for reamainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operation to be done(in blocks of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * 7)[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * 7);
|
|
|
|
ymm8 = _mm256_setzero_pd();
|
|
ymm9 = _mm256_setzero_pd();
|
|
ymm10 = _mm256_setzero_pd();
|
|
ymm11 = _mm256_setzero_pd();
|
|
ymm12 = _mm256_setzero_pd();
|
|
ymm13 = _mm256_setzero_pd();
|
|
ymm14 = _mm256_setzero_pd();
|
|
ymm15 = _mm256_setzero_pd();
|
|
|
|
///GEMM code Begins///
|
|
for(k = 0; k< k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[0][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[0][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[0][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[0][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0] )
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[0][4]*A10[0][0] B01[0][4]*A10[1][0] B01[0][4]*A10[2][0] B01[0][4]*A10[3][0])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[0][5]*A10[0][0] B01[0][5]*A10[1][0] B01[0][5]*A10[2][0] B01[0][5]*A10[3][0])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[0][6]*A10[0][0] B01[0][6]*A10[1][0] B01[0][6]*A10[2][0] B01[0][6]*A10[3][0])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[0][7]*A10[0][0] B01[0][7]*A10[1][0] B01[0][7]*A10[2][0] B01[0][7]*A10[3][0])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 1)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[1][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[1][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[1][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[1][7]
|
|
|
|
b01 += 1; //move to next row of B01
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[1][4]*A10[0][1] B01[1][4]*A10[1][1] B01[1][4]*A10[2][1] B01[1][4]*A10[3][1])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[1][5]*A10[0][1] B01[1][5]*A10[1][1] B01[1][5]*A10[2][1] B01[1][5]*A10[3][1])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[1][6]*A10[0][1] B01[1][6]*A10[1][1] B01[1][6]*A10[2][1] B01[1][6]*A10[3][1])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[1][7]*A10[0][1] B01[1][7]*A10[1][1] B01[1][7]*A10[2][1] B01[1][7]*A10[3][1])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] //A10[1][2] A10[2][2] A10[3][2]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[2][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[2][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[2][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[2][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[2][4]*A10[0][2] B01[2][4]*A10[1][2] B01[2][4]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[2][5]*A10[0][2] B01[2][5]*A10[1][2] B01[2][5]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[2][6]*A10[0][2] B01[2][6]*A10[1][2] B01[2][6]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[2][7]*A10[0][2] B01[2][7]*A10[1][2] B01[2][7]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm16 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm4 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
ymm0 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 4)); //B01[3][4]
|
|
ymm1 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 5)); //B01[3][5]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 6)); //B01[3][6]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 7)); //B01[3][7]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm8 = _mm256_fmadd_pd(ymm4, ymm16, ymm8); //ymm8 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm9 = _mm256_fmadd_pd(ymm5, ymm16, ymm9); //ymm9 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm10 = _mm256_fmadd_pd(ymm6, ymm16, ymm10); //ymm10 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm11 = _mm256_fmadd_pd(ymm7, ymm16, ymm11); //ymm11 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
ymm12 = _mm256_fmadd_pd(ymm0, ymm16, ymm12); //ymm12 += (B01[3][4]*A10[0][3] B01[3][4]*A10[1][3] B01[3][4]*A10[2][3] B01[3][4]*A10[3][3])
|
|
ymm13 = _mm256_fmadd_pd(ymm1, ymm16, ymm13); //ymm13 += (B01[3][5]*A10[0][3] B01[3][5]*A10[1][3] B01[3][5]*A10[2][3] B01[3][5]*A10[3][3])
|
|
ymm14 = _mm256_fmadd_pd(ymm2, ymm16, ymm14); //ymm14 += (B01[3][6]*A10[0][3] B01[3][6]*A10[1][3] B01[3][6]*A10[2][3] B01[3][6]*A10[3][3])
|
|
ymm15 = _mm256_fmadd_pd(ymm3, ymm16, ymm15); //ymm15 += (B01[3][7]*A10[0][3] B01[3][7]*A10[1][3] B01[3][7]*A10[2][3] B01[3][7]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *0)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *1)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b *2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b *3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b *6)); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); //B11[0-3][0] *alpha -= B01[0-3][0]
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); //B11[0-3][1] *alpha -= B01[0-3][1]
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); //B11[0-3][2] *alpha -= B01[0-3][2]
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); //B11[0-3][3] *alpha -= B01[0-3][3]
|
|
ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); //B11[0-3][4] *alpha -= B01[0-3][4]
|
|
ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); //B11[0-3][5] *alpha -= B01[0-3][5]
|
|
ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); //B11[0-3][6] *alpha -= B01[0-3][6]
|
|
ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); //B11[0-3][7] *alpha -= B01[0-3][7]
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[2][0-3] -= B11[0-3][0]*A11[2][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[6][0-3] -= B11[0-3][4]*A11[2][4]
|
|
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11 +2)); //A11[2][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(ROw2): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); //B11[2][0-3] -= A11[2][1] * B11[0-3][1]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); //B11[6][0-3] -= A11[2][1] * B11[0-3][5]
|
|
|
|
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
|
|
ymm15 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); //B11[6][0] B11[7][0] B11[6][2] B11[7][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
|
|
ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); //B11[2][5] B11[3][5] B11[2][7] B11[3][7]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x08);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x08);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x08);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x08);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x08);
|
|
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
|
|
///unpacklow///
|
|
ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); //B11[0][4] B11[0][5] B11[1][4] B11[1][5]
|
|
ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); //B11[0][6] B11[0][7] B11[1][6] B11[1][7]
|
|
|
|
//rearrange low elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm9,ymm11,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm10 = _mm256_permute2f128_pd(ymm9,ymm11,0x31); //B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
ymm12 = _mm256_permute2f128_pd(ymm13,ymm15,0x20); //B11[4][0] B11[4][1] B11[4][2] B11[4][3]
|
|
ymm14 = _mm256_permute2f128_pd(ymm13,ymm15,0x31); //B11[6][0] B11[6][1] B11[6][2] B11[6][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); //B11[5][0] B11[5][1] B11[7][0] B11[7][1]
|
|
ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); //B11[5][2] B11[5][3] B11[7][2] B11[7][3]
|
|
|
|
//rearrange high elements
|
|
ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
ymm13 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); //B11[5][0] B11[5][1] B11[5][2] B11[5][3]
|
|
ymm15 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); //B11[7][0] B11[7][1] B11[7][2] B11[7][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11 +1)); //A11[1][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); //B11[1][0-3] -= B11[0-3][0]*A11[1][0]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); //B11[5][0-3] -= B11[0-3][4]*A11[1][4]
|
|
|
|
ymm10 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//unpacklow//
|
|
ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
|
|
ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); //B11[4][0] B11[5][0] B11[4][2] B11[5][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1, ymm10, 0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1, ymm10, 0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
ymm4 = _mm256_permute2f128_pd(ymm5, ymm10, 0x20); //B11[0][4] B11[1][4] B11[2][4] B11[3][4]
|
|
ymm6 = _mm256_permute2f128_pd(ymm5, ymm10, 0x31); //B11[0][6] B11[1][6] B11[2][6] B11[3][6]
|
|
|
|
///unpack high///
|
|
ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); //B11[0][5] B11[1][5] B11[0][7] B11[1][7]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm8, ymm10, 0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm8, ymm10, 0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm5 = _mm256_permute2f128_pd(ymm12, ymm10, 0x20); //B11[0][5] B11[1][5] B11[2][5] B11[3][5]
|
|
ymm7 = _mm256_permute2f128_pd(ymm12, ymm10, 0x31); //B11[0][7] B11[1][7] B11[2][7] B11[3][7]
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm8, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm9, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm10, 0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm11, 0x30);
|
|
ymm4 = _mm256_permute2f128_pd(ymm4, ymm12, 0x30);
|
|
ymm5 = _mm256_permute2f128_pd(ymm5, ymm13, 0x30);
|
|
ymm6 = _mm256_permute2f128_pd(ymm6, ymm14, 0x30);
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm15, 0x30);
|
|
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm8 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); //load B11[0-3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); //load B11[0-3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //load B11[0-3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //load B11[0-3][3]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + cs_b * 4)); //load B11[0-3][4]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b * 5)); //load B11[0-3][5]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b * 6)); //load B11[0-3][6]
|
|
ymm15 = _mm256_loadu_pd((double const *)(f_temp)); //load B11[0-3][7]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm8, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm9, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm10, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm11, 0x0E);
|
|
ymm4 = _mm256_blend_pd(ymm4, ymm12, 0x0E);
|
|
ymm5 = _mm256_blend_pd(ymm5, ymm13, 0x0E);
|
|
ymm6 = _mm256_blend_pd(ymm6, ymm14, 0x0E);
|
|
ymm7 = _mm256_blend_pd(ymm7, ymm15, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); //store(B11[0-3][4])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); //store(B11[0-3][5])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); //store(B11[0-3][6])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm7); //store(B11[0-3][7])
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * 7)[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
}
|
|
|
|
if((n & 4)) //implementation for remainder columns(when 'n_remainder' is greater than 4)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of times GEMM to be performed(in block of 4)
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b*3)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a*2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[1][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[2][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[3][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B01[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B01[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B01[0-3][3] *alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]*B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]*B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0]*B11[0][0-3]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]*B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1]*B11[1][0-3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2]*B11[2][0-3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[0-3][3])
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+4) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * 3)[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * 3);
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
for(k = 0; k < k_iter; k++) //looop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[0][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B01[0][3]*A10[0][0] B01[0][3]*A10[1][0] B01[0][3]*A10[2][0] B01[0][3]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[1][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B01[1][3]*A10[0][1] B01[1][3]*A10[1][1] B01[1][3]*A10[2][1] B01[1][3]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[2][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B01[2][3]*A10[0][2] B01[2][3]*A10[1][2] B01[2][3]*A10[2][2] B01[2][3]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 3)); //B01[3][3]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B01[3][3]*A10[0][3] B01[3][3]*A10[1][3] B01[3][3]*A10[2][3] B01[3][3]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0]* B11[0][0-3]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1]* B11[1][0-3]
|
|
|
|
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//load 4x4 block from b11
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0]* B11[0][0-3]
|
|
|
|
ymm11 = _mm256_broadcast_sd((double const *)(&ones));
|
|
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
//load 4x4 block from b11
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1, ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2, ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3, ymm7,0x30);
|
|
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
//load 4x4 block from b11
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm7 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][2] B11[3][3]
|
|
|
|
//determine correct values to store
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[0-3][3])
|
|
|
|
if((j+4) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * 3)[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
|
|
n_remainder -= 4;
|
|
j += 4;
|
|
|
|
}
|
|
|
|
if(n_remainder) //implementation fo remaining columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
for(i = 0;i+D_MR-1 < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = i / D_MR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
//load 4x4 block from b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[0][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[1][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[2][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B01[3][2]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[0][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[1][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[2][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B01[3][1]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[0][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[1][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[2][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B01[3][0]
|
|
|
|
b01 += 1;
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha Value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] *alpha -= ymm4
|
|
ymm1 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][1] *alpha -= ymm5
|
|
ymm2 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][2] *alpha -= ymm6
|
|
ymm3 = _mm256_broadcast_sd((double const *)(&ones)); //B11[0-3][3] *alpha -= ymm7
|
|
|
|
}
|
|
|
|
///implement TRSM///
|
|
//1st col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11 + 2)); //A11[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11 + 3)); //A11[3][2]
|
|
|
|
////unpacklow////
|
|
ymm8 = _mm256_unpacklo_pd(ymm0, ymm1); //B11[0][0] B11[0][1] B11[2][0] B11[2][1]
|
|
ymm13 = _mm256_unpacklo_pd(ymm2, ymm3); //B11[0][2] B11[0][3] B11[2][2] B11[2][3]
|
|
|
|
//rearrange low elements
|
|
ymm4 = _mm256_permute2f128_pd(ymm8,ymm13,0x20); //B11[0][0] B11[0][1] B11[0][2] B11[0][3]
|
|
ymm11 = _mm256_permute2f128_pd(ymm8,ymm13,0x31);//B11[2][0] B11[2][1] B11[2][2] B11[2][3]
|
|
|
|
////unpackhigh////
|
|
ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); //B11[1][0] B11[1][1] B11[3][0] B11[3][1]
|
|
ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); //B11[1][2] B11[1][3] B11[3][2] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm8 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); //B11[1][0] B11[1][1] B11[1][2] B11[1][3]
|
|
ymm13 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); //B11[3][0] B11[3][1] B11[3][2] B11[3][3]
|
|
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (3, 0)
|
|
ymm8 = _mm256_fnmadd_pd(ymm5, ymm4, ymm8);//d = c - (a*b) //B11[1][0-3] -= A11[1][0] * B11[0][0-3]
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm4, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][0] * B11[0][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm7, ymm4, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][0] * B11[0][0-3]
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
ymm11 = _mm256_fnmadd_pd(ymm9, ymm8, ymm11);//d = c - (a*b) //B11[2][0-3] -= A11[2][1] * B11[1][0-3]
|
|
ymm13 = _mm256_fnmadd_pd(ymm10, ymm8, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][1] * B11[1][0-3]
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
ymm13 = _mm256_fnmadd_pd(ymm12, ymm11, ymm13);//d = c - (a*b) //B11[3][0-3] -= A11[3][2] * B11[2][0-3]
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
ymm1 = _mm256_unpacklo_pd(ymm4, ymm8); //B11[0][0] B11[1][0] B11[0][2] B11[1][2]
|
|
ymm3 = _mm256_unpacklo_pd(ymm11, ymm13); //B11[2][0] B11[3][0] B11[2][2] B11[3][2]
|
|
|
|
//rearrange low elements
|
|
ymm0 = _mm256_permute2f128_pd(ymm1,ymm3,0x20); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_permute2f128_pd(ymm1,ymm3,0x31); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
////unpackhigh////
|
|
ymm14 = _mm256_unpackhi_pd(ymm4, ymm8); //B11[0][1] B11[1][1] B11[0][3] B11[1][3]
|
|
|
|
ymm15 = _mm256_unpackhi_pd(ymm11, ymm13); //B11[2][1] B11[3][1] B11[2][3] B11[3][3]
|
|
|
|
//rearrange high elements
|
|
ymm1 = _mm256_permute2f128_pd(ymm14,ymm15,0x20); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_permute2f128_pd(ymm14,ymm15,0x31); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2])
|
|
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
}
|
|
if(m_remainder) //implementation for remainder rows(when 'M' is not a multiple of D_MR)
|
|
{
|
|
a10 = L +i; //pointer to block of A to be used for GEMM
|
|
a11 = L + i + (i*cs_a); //pointer to block of A to be used for TRSM
|
|
b01 = B + j*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j* cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
|
|
k_iter = i / D_MR; //number of times GEMM operations to be performed
|
|
|
|
dim_t iter;
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * (n_remainder -1))[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * (n_remainder -1));
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM for previously calculated values ///
|
|
|
|
|
|
//load 4x4 block from b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[0][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B01[0][2]*A10[0][0] B01[0][2]*A10[1][0] B01[0][2]*A10[2][0] B01[0][2]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[1][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B01[1][2]*A10[0][1] B01[1][2]*A10[1][1] B01[1][2]*A10[2][1] B01[1][2]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[2][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B01[2][2]*A10[0][2] B01[2][2]*A10[1][2] B01[2][2]*A10[2][2] B01[2][2]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 2)); //B10[3][2]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B01[3][2]*A10[0][3] B01[3][2]*A10[1][3] B01[3][2]*A10[2][3] B01[3][2]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
|
|
ymm10 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[0-3][2] * alpha -= ymm6
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm10, ymm2, 0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm10, ymm2, 0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + (cs_b)), ymm1); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm2); //store(B11[0-3][2])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(b11)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[0][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B01[0][1]*A10[0][0] B01[0][1]*A10[1][0] B01[0][1]*A10[2][0] B01[0][1]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[1][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B01[1][1]*A10[0][1] B01[1][1]*A10[1][1] B01[1][1]*A10[2][1] B01[1][1]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[2][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B01[2][1]*A10[0][2] B01[2][1]*A10[1][2] B01[2][1]*A10[2][2] B01[2][1]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 1)); //B10[3][1]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B01[3][1]*A10[0][3] B01[3][1]*A10[1][3] B01[3][1]*A10[2][3] B01[3][1]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
ymm9 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[0-3][1] * alpha -= ymm5
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm9, ymm1, 0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[0-3][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_b01_dup = b01;
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)(a10)); //A10[0][0] A10[1][0] A10[2][0] A10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(a10 + cs_a)); //A10[0][1] A10[1][1] A10[2][1] A10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); //A10[0][2] A10[1][2] A10[2][2] A10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); //A10[0][3] A10[1][3] A10[2][3] A10[3][3]
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[0][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B01[0][0]*A10[0][0] B01[0][0]*A10[1][0] B01[0][0]*A10[2][0] B01[0][0]*A10[3][0])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[1][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B01[1][0]*A10[0][1] B01[1][0]*A10[1][1] B01[1][0]*A10[2][1] B01[1][0]*A10[3][1])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[2][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B01[2][0]*A10[0][2] B01[2][0]*A10[1][2] B01[2][0]*A10[2][2] B01[2][0]*A10[3][2])
|
|
|
|
ymm12 = _mm256_broadcast_sd((double const *)(b01 + cs_b * 0)); //B10[3][0]
|
|
|
|
b01 += 1; //move to next row of B
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B01[3][0]*A10[0][3] B01[3][0]*A10[1][3] B01[3][0]*A10[2][3] B01[3][0]*A10[3][3])
|
|
|
|
a10 += D_MR * cs_a; //pointer math to find next block of A for GEMM
|
|
b01 = ptr_b01_dup + D_MR; //pointer math to find next block of B for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to hold alpha value
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[0-3][0] * alpha -= ymm4
|
|
|
|
///implement TRSM///
|
|
//determine correct values to store
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm8, ymm0, 0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm8, ymm0, 0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)(f_temp), ymm0); //store(B11[0-3][0])
|
|
}
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
|
|
}
|
|
///scalar code for trsm without alpha///
|
|
dtrsm_small_AlXB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/* Implements TRSM for the case XA = alpha * B
 * A is upper triangular, non-unit diagonal, no transpose
 * dimensions:  X: m x n   A: n x n   B: m x n
 */
|
|
|
|
/*  b11--->             a01 ---->
    *****************   ***********
    *b01*b11*   *   *   *  *      *
    *   *   *   *   *    *a01*    *  a11
 |  *****************     *********  |
 |  *   *   *   *   *      *a11*  *  |
 |  *   *   *   *   *       *   * *  |
 v  *****************        ******  v
    *   *   *   *   *         *   *
    *   *   *   *   *          *  *
    *****************           * *
                                  *
*/
|
|
static err_t bli_dtrsm_small_XAuB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME)
|
|
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
|
|
double* f_temp;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0 (GEMM sum, rows 0-3 col 0)
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[0-3][1] * alpha -= ymm1 (GEMM sum, rows 0-3 col 1)
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][2] * alpha -= ymm2 (GEMM sum, rows 0-3 col 2)
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[0-3][3] * alpha -= ymm3 (GEMM sum, rows 0-3 col 3)
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[4-7][0] * alpha -= ymm4 (GEMM sum, rows 4-7 col 0)
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][1] * alpha -= ymm5 (GEMM sum, rows 4-7 col 1)
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[4-7][2] * alpha -= ymm6 (GEMM sum, rows 4-7 col 2)
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7 (GEMM sum, rows 4-7 col 3)
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0); //B11[0-3][3] /= A11[3][3]
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0); //B11[4-7][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= ymm0 (GEMM sum, rows 0-3 col 0)
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[0-3][1] * alpha -= ymm1 (GEMM sum, rows 0-3 col 1)
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][2] * alpha -= ymm2 (GEMM sum, rows 0-3 col 2)
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[4-7][0] * alpha -= ymm4 (GEMM sum, rows 4-7 col 0)
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][1] * alpha -= ymm5 (GEMM sum, rows 4-7 col 1)
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[4-7][2] * alpha -= ymm6 (GEMM sum, rows 4-7 col 2)
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(&ones)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] 1.0 A11[2][2] 1.0 (ymm6 is the ones placeholder for the missing 4th diagonal)
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] 1.0
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1.0)
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0); //B11[0-3][2] /= A11[2][2]
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0); //B11[4-7][2] /= A11[2][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= ymm0 (GEMM sum, rows 0-3 col 0)
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[0-3][1] * alpha -= ymm1 (GEMM sum, rows 0-3 col 1)
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[4-7][0] * alpha -= ymm4 (GEMM sum, rows 4-7 col 0)
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][1] * alpha -= ymm5 (GEMM sum, rows 4-7 col 1)
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C); //A11[0][0] A11[1][1] 1 1
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/1 1/1)
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0); //B11[0-3][1] /= A11[1][1]
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0); //B11[4-7][1] /= A11[1][1]
|
|
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= ymm0 (GEMM sum, rows 0-3 col 0)
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[4-7][0] * alpha -= ymm4 (GEMM sum, rows 4-7 col 0)
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] broadcast in all four lanes; only one diagonal element in this remainder case)
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
            a11 = L + j*cs_a + j;   //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM for previous blocks ///
|
|
                ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm14); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM for previous blocks ///
|
|
                ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));//A11[1][1]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM for previous blocks ///
|
|
                ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
    if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
            k_iter = j / D_NR;      //number of times GEMM to be performed(in blocks of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * 3)[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * 3);
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
            ///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[3][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)f_temp);
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * 3)[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * (n_remainder-1));
|
|
///GEMM for previous blocks ///
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///load 4x4 block of b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM implementation starts///
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)f_temp, ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM implementation starts///
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)f_temp); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM implementation starts///
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
|
|
}
|
|
if((j+n_remainder) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAuB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* Implements TRSM for the case XA = alpha * B
|
|
*A is upper triangular, unit-diagonal, no transpose
|
|
*dimensions: X: m x n, A: n x n, B: m x n
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b10*b11* * * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAuB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME)
|
|
|| (m>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUB_ROME_COLUMN_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n)>D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES && (m/n) < D_BLIS_SMALL_MATRIX_THRES_TRSM_DIM_RATIO)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
|
|
double* f_temp;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch reginsters
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
else if(1 == n_remainder)
{
    // Case: exactly one column of B remains (n % 4 == 1), 8-row panel.
    // Computes B11[0-7][0] = alpha*B11[0-7][0] - A01*B10 (rank-j update).
    // No triangular-solve FMAs are needed: with a single remaining column
    // there are no off-diagonal columns to eliminate. NOTE(review): the
    // diagonal entry of A11 is never divided out here (same as the 2- and
    // 3-column cases) — presumably the unit-diagonal case; confirm against
    // the caller's diag handling.

    ///GEMM implementation begins///

    for(k = 0; k < k_iter; k++)      ///loop for number of GEMM operations
    {
        ptr_a01_dup = a01;    // remember column start so we can hop to the next 4x1 block of A

        //broadcast 1st row of A01 (only column 0 is needed for a 1-column remainder)
        ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0));    //A01[0][0]

        a01 += 1;    //move to next row of A

        //load 8x2 block of B10 (columns 0 and 1, rows 0-7)
        ymm12 = _mm256_loadu_pd((double const *)b10);    //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
        ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR));    //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
        ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b));    //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
        ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]

        ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);    //ymm0 += B10[0-3][0] * A01[0][0]

        ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);    //ymm4 += B10[4-7][0] * A01[0][0]

        //broadcast 2nd row of A01
        ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0));    //A01[1][0]

        a01 += 1;    //move to next row of A

        ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);    //ymm0 += B10[0-3][1] * A01[1][0]

        ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);    //ymm4 += B10[4-7][1] * A01[1][0]

        //broadcast 3rd row of A01
        ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0));    //A01[2][0]

        a01 += 1;    //move to next row of A

        //load next 8x2 block of B10 (columns 2 and 3, rows 0-7)
        ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));    //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
        ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR));    //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
        ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b));    //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
        ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR));    //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])

        ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);    //ymm0 += B10[0-3][2] * A01[2][0]

        ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);    //ymm4 += B10[4-7][2] * A01[2][0]

        //broadcast 4th row of A01
        ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0));    //A01[3][0]

        a01 += 1;    //move to next row of A

        ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);    //ymm0 += B10[0-3][3] * A01[3][0]

        ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);    //ymm4 += B10[4-7][3] * A01[3][0]

        b10 += D_NR * cs_b;    //pointer math to find next block of B for GEMM
        a01 = ptr_a01_dup + D_NR;    //pointer math to find next block of A for GEMM
    }
    ///GEMM code ends///
    ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    // broadcast alpha

    ymm8 = _mm256_loadu_pd((double const *)b11);    //B11[0-3][0]
    ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));    //B11[4-7][0]

    ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0);    //B11[0-3][0] * alpha -= B10[0-3][0]
    ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4);    //B11[4-7][0] * alpha -= B10[4-7][0]

    _mm256_storeu_pd((double *)b11, ymm8);    //store(B11[0-3][0])
    _mm256_storeu_pd((double *)(b11 + D_NR), ymm12);    //store(B11[4-7][0])
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM for previous blocks ///
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM for previous blocks ///
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM for previous blocks ///
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b_offset[1])[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b_offset[1]);
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][2]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][3]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[1][3]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b_offset[1])[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * (n_remainder-1));
|
|
|
|
///GEMM for previous blocks ///
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///load 4x4 block of b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM implementation starts///
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)f_temp, ymm2); //(store(B11[x][2]))
|
|
}
|
|
if(2 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM implementation starts///
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
|
|
}
|
|
if(1 == n_remainder)
|
|
{
|
|
ymm0 = _mm256_loadu_pd((double const *)f_temp); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM implementation starts///
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAuB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower triangular, non-unit diagonal, transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b01*b11* * * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAltB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N)
|
|
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N)
|
|
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|
|
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|
|
|| (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
|
|
double* f_temp;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
            for(k = 0; k < k_iter; k++)    //loop for number of GEMM operations: each iteration consumes a 4x4 block of A01 and an 8x4 block of B10
            {
                ptr_a01_dup = a01;    //remember start of this A01 block; restored (advanced by 4 rows) at loop bottom

                //broadcast 1st row of A01
                ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[0][0]
                ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[0][1]
                ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[0][2]
                ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[0][3]

                a01 += cs_a;    //move to next row

                //load first 8x2 block of B10 (columns 0 and 1, 8 rows each)
                ymm12 = _mm256_loadu_pd((double const *)b10);                    //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
                ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR));           //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
                ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b));           //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
                ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));    //B10[4][1] B10[5][1] B10[6][1] B10[7][1]

                //accumulate contribution of B10 column 0 into all four output columns (rows 0-3)
                ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);      //ymm0 += B10[0-3][0] * A01[0][0]
                ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1);      //ymm1 += B10[0-3][0] * A01[0][1]
                ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2);     //ymm2 += B10[0-3][0] * A01[0][2]
                ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3);     //ymm3 += B10[0-3][0] * A01[0][3]

                //same for rows 4-7
                ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);      //ymm4 += B10[4-7][0] * A01[0][0]
                ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5);      //ymm5 += B10[4-7][0] * A01[0][1]
                ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6);     //ymm6 += B10[4-7][0] * A01[0][2]
                ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7);     //ymm7 += B10[4-7][0] * A01[0][3]

                //broadcast 2nd row of A01
                ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[1][0]
                ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[1][1]
                ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[1][2]
                ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[1][3]

                a01 += cs_a;    //move to next row of A

                //accumulate contribution of B10 column 1 (rows 0-3)
                ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);      //ymm0 += B10[0-3][1] * A01[1][0]
                ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1);      //ymm1 += B10[0-3][1] * A01[1][1]
                ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2);     //ymm2 += B10[0-3][1] * A01[1][2]
                ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3);     //ymm3 += B10[0-3][1] * A01[1][3]

                //rows 4-7
                ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);      //ymm4 += B10[4-7][1] * A01[1][0]
                ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5);      //ymm5 += B10[4-7][1] * A01[1][1]
                ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6);     //ymm6 += B10[4-7][1] * A01[1][2]
                ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7);     //ymm7 += B10[4-7][1] * A01[1][3]

                //broadcast 3rd row of A01
                ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[2][0]
                ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[2][1]
                ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[2][2]
                ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[2][3]

                a01 += cs_a;    //move to next row of A01

                //load next 8x2 block of B10 (columns 2 and 3); reuses ymm12-15 now that columns 0/1 are consumed
                ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));               //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
                ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR));        //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
                ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b));        //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
                ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])

                //accumulate contribution of B10 column 2 (rows 0-3)
                ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);      //ymm0 += B10[0-3][2] * A01[2][0]
                ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1);      //ymm1 += B10[0-3][2] * A01[2][1]
                ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2);     //ymm2 += B10[0-3][2] * A01[2][2]
                ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3);     //ymm3 += B10[0-3][2] * A01[2][3]

                //rows 4-7
                ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);      //ymm4 += B10[4-7][2] * A01[2][0]
                ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5);      //ymm5 += B10[4-7][2] * A01[2][1]
                ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6);     //ymm6 += B10[4-7][2] * A01[2][2]
                ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7);     //ymm7 += B10[4-7][2] * A01[2][3]

                //broadcast 4th row of A01
                ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[3][0]
                ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[3][1]
                ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[3][2]
                ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[3][3]

                a01 += cs_a;    //move to next row of A01

                //accumulate contribution of B10 column 3 (rows 0-3)
                ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);      //ymm0 += B10[0-3][3] * A01[3][0]
                ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1);      //ymm1 += B10[0-3][3] * A01[3][1]
                ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2);     //ymm2 += B10[0-3][3] * A01[3][2]
                ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3);     //ymm3 += B10[0-3][3] * A01[3][3]

                //rows 4-7
                ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);      //ymm4 += B10[4-7][3] * A01[3][0]
                ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5);      //ymm5 += B10[4-7][3] * A01[3][1]
                ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6);     //ymm6 += B10[4-7][3] * A01[3][2]
                ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7);     //ymm7 += B10[4-7][3] * A01[3][3]

                b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
                a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
            }
|
|
|
|
            ///GEMM code ends///

            ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);    //alpha in all four lanes

            //load 8x4 block of B11
            ymm8 = _mm256_loadu_pd((double const *)b11);                                //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
            ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));                      //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
            ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b));                       //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
            ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR));               //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
            ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));            //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
            ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR));     //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
            ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));            //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
            ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR));     //B11[4][3] B11[5][3] B11[6][3] B11[7][3]

            //B11 := alpha*B11 - (GEMM accumulators)
            ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0);      //B11[0-3][0] * alpha -= ymm0
            ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1);      //B11[0-3][1] * alpha -= ymm1
            ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2);    //B11[0-3][2] * alpha -= ymm2
            ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3);    //B11[0-3][3] * alpha -= ymm3

            ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4);    //B11[4-7][0] * alpha -= ymm4
            ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5);    //B11[4-7][1] * alpha -= ymm5
            ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6);    //B11[4-7][2] * alpha -= ymm6
            ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7);    //B11[4-7][3] * alpha -= ymm7

            ///implement TRSM///

            ///read 4x4 block of A11///
            //A11 element comments below use row-major indexing: A11[r][c] = a11 + c + r*cs_a

            ymm7 = _mm256_broadcast_sd((double const *)(&ones));    //1.0, numerator for the reciprocals

            //1st col
            ymm0 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]

            //2nd col
            a11 += 1;
            ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][1]
            ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][1]

            //3rd col
            a11 += 1;
            ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][2]
            ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][2]
            ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2));    //A11[2][2]

            //4th col
            a11 += 1;
            ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3));    //A11[3][3]

            //compute reciprocals of L(i,i) and broadcast in registers
            ymm0 = _mm256_unpacklo_pd(ymm0, ymm2);    //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
            ymm2 = _mm256_unpacklo_pd(ymm5, ymm6);    //A11[2][2] A11[3][3] A11[2][2] A11[3][3]

            ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C);    //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
            ymm7 = _mm256_div_pd(ymm7, ymm0);            //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])

            //remaining entries of column 3 (ymm2/ymm5 are re-purposed now that the diagonal is packed in ymm7)
            ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][3]
            ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][3]
            ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2));    //A11[2][3]

            //extract a00
            ymm0 = _mm256_permute_pd(ymm7, 0x00);              //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
            ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);   //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])

            ymm8 = _mm256_mul_pd(ymm8, ymm0);      //B11[0-3][0] /= A11[0][0]

            ymm12 = _mm256_mul_pd(ymm12, ymm0);    //B11[4-7][0] /= A11[0][0]

            //extract a11
            ymm0 = _mm256_permute_pd(ymm7, 0x03);              //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
            ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);   //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])

            //(Row1): FMA operations — eliminate column 0 from columns 1..3
            ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9);      //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
            ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10);    //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
            ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11);    //B11[0-3][3] -= B11[0-3][0] * A11[0][3]

            ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13);   //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
            ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14);   //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
            ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15);   //B11[4-7][3] -= B11[4-7][0] * A11[0][3]

            ymm9 = _mm256_mul_pd(ymm9, ymm0);      //B11[0-3][1] /= A11[1][1]

            ymm13 = _mm256_mul_pd(ymm13, ymm0);    //B11[4-7][1] /= A11[1][1]

            //extract a22
            ymm0 = _mm256_permute_pd(ymm7, 0x00);              //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
            ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);   //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])

            //(Row2)FMA operations — eliminate column 1 from columns 2..3
            ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10);    //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
            ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11);    //B11[0-3][3] -= B11[0-3][1] * A11[1][3]

            ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14);   //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
            ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15);   //B11[4-7][3] -= B11[4-7][1] * A11[1][3]

            ymm10 = _mm256_mul_pd(ymm10, ymm0);    //B11[0-3][2] /= A11[2][2]

            ymm14 = _mm256_mul_pd(ymm14, ymm0);    //B11[4-7][2] /= A11[2][2]

            //extract a33
            ymm0 = _mm256_permute_pd(ymm7, 0x0C);              //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
            ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);   //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])

            //(Row3)FMA operations — eliminate column 2 from column 3
            ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11);   //B11[0-3][3] -= B11[0-3][2] * A11[2][3]

            ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15);   //B11[4-7][3] -= B11[4-7][2] * A11[2][3]

            ymm11 = _mm256_mul_pd(ymm11, ymm0);    //B11[0-3][3] /= A11[3][3]

            ymm15 = _mm256_mul_pd(ymm15, ymm0);    //B11[4-7][3] /= A11[3][3]

            //write solved 8x4 block back to B11
            _mm256_storeu_pd((double *)b11, ymm8);                                  //store(B11[0-3][0])
            _mm256_storeu_pd((double *)(b11 + D_NR), ymm12);                        //store(B11[4-7][0])
            _mm256_storeu_pd((double *)(b11 + cs_b), ymm9);                         //store(B11[0-3][1])
            _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13);                 //store(B11[4-7][1])
            _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10);              //store(B11[0-3][2])
            _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);       //store(B11[4-7][2])
            _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11);       //store(B11[0-3][3])
            _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
        }
|
|
        if(n_remainder)    //implementation for remainder columns(when n is not multiple of D_NR)
        {
            a01 = L + j;                //pointer to block of A to be used for GEMM
            a11 = L + j*cs_a + j;       //pointer to block of A to be used for TRSM
            b10 = B + i;                //pointer to block of B to be used for GEMM
            b11 = B + i + j*cs_b;       //pointer to block of B to be used for TRSM

            k_iter = j / D_NR;          //number of GEMM operations to be performed(in blocks of 4x4)

            //zero the GEMM accumulators before the update loop below
            ymm0 = _mm256_setzero_pd();
            ymm1 = _mm256_setzero_pd();
            ymm2 = _mm256_setzero_pd();
            ymm3 = _mm256_setzero_pd();
            ymm4 = _mm256_setzero_pd();
            ymm5 = _mm256_setzero_pd();
            ymm6 = _mm256_setzero_pd();
            ymm7 = _mm256_setzero_pd();

            //subtract the calculated GEMM block from current TRSM block
            //dispatch on how many columns remain (3, 2 or 1)
|
|
            if(3 == n_remainder)
            {
                ///GEMM implementation begins///

                for(k = 0; k < k_iter; k++)    ///loop for number of GEMM operations
                {
                    ptr_a01_dup = a01;    //remember start of this A01 block

                    //broadcast 1st row of A01 (only 3 columns needed)
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[0][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[0][1]
                    ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[0][2]

                    a01 += cs_a;    //move to next row of A

                    //load 8x2 block of B10
                    ymm12 = _mm256_loadu_pd((double const *)b10);                   //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
                    ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR));          //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
                    ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b));          //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
                    ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));   //B10[4][1] B10[5][1] B10[6][1] B10[7][1]

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);      //ymm0 += B10[0-3][0] * A01[0][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1);      //ymm1 += B10[0-3][0] * A01[0][1]
                    ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2);     //ymm2 += B10[0-3][0] * A01[0][2]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);      //ymm4 += B10[4-7][0] * A01[0][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5);      //ymm5 += B10[4-7][0] * A01[0][1]
                    ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6);     //ymm6 += B10[4-7][0] * A01[0][2]

                    //broadcast 2nd row of A01
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[1][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[1][1]
                    ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[1][2]

                    a01 += cs_a;    //move to next row of A

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);      //ymm0 += B10[0-3][1] * A01[1][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1);      //ymm1 += B10[0-3][1] * A01[1][1]
                    ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2);     //ymm2 += B10[0-3][1] * A01[1][2]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);      //ymm4 += B10[4-7][1] * A01[1][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5);      //ymm5 += B10[4-7][1] * A01[1][1]
                    ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6);     //ymm6 += B10[4-7][1] * A01[1][2]

                    //broadcast 3rd row of A01
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[2][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[2][1]
                    ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[2][2]

                    a01 += cs_a;    //move to next row of A

                    //load next 8x2 block of B10
                    ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));               //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
                    ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR));        //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
                    ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b));        //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
                    ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);      //ymm0 += B10[0-3][2] * A01[2][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1);      //ymm1 += B10[0-3][2] * A01[2][1]
                    ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2);     //ymm2 += B10[0-3][2] * A01[2][2]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);      //ymm4 += B10[4-7][2] * A01[2][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5);      //ymm5 += B10[4-7][2] * A01[2][1]
                    ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6);     //ymm6 += B10[4-7][2] * A01[2][2]

                    //broadcast 4th row of A01
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[3][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[3][1]
                    ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[3][2]

                    a01 += cs_a;    //move to next row of A

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);      //ymm0 += B10[0-3][3] * A01[3][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1);      //ymm1 += B10[0-3][3] * A01[3][1]
                    ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2);     //ymm2 += B10[0-3][3] * A01[3][2]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);      //ymm4 += B10[4-7][3] * A01[3][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5);      //ymm5 += B10[4-7][3] * A01[3][1]
                    ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6);     //ymm6 += B10[4-7][3] * A01[3][2]

                    b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
                    a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
                }

                ///GEMM code ends///
                ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    //alpha in all four lanes

                //load 8x3 block of B11
                ymm8 = _mm256_loadu_pd((double const *)b11);                               //B11[0-3][0]
                ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));                     //B11[4-7][0]
                ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b));                      //B11[0-3][1]
                ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR));              //B11[4-7][1]
                ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));           //B11[0-3][2]
                ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR));    //B11[4-7][2]

                //B11 := alpha*B11 - (GEMM accumulators)
                ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0);      //B11[0-3][0] * alpha -= ymm0
                ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1);      //B11[0-3][1] * alpha -= ymm1
                ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2);    //B11[0-3][2] * alpha -= ymm2
                ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4);    //B11[4-7][0] * alpha -= ymm4
                ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5);    //B11[4-7][1] * alpha -= ymm5
                ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6);    //B11[4-7][2] * alpha -= ymm6

                ///implement TRSM///

                ///read 4x4 block of A11///
                //A11 element comments below use row-major indexing: A11[r][c] = a11 + c + r*cs_a

                ymm7 = _mm256_broadcast_sd((double const *)(&ones));    //1.0, numerator for the reciprocals

                //1st col
                ymm0 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]

                //2nd col
                a11 += 1;
                ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][1]
                ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][1]

                //3rd col
                a11 += 1;
                ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][2]
                ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][2]
                ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2));    //A11[2][2]

                //4th col
                a11 += 1;
                ymm6 = _mm256_broadcast_sd((double const *)(&ones));    //stand-in for A11[3][3]: only 3 columns active, use 1.0 so the reciprocal lane is harmless

                //compute reciprocals of L(i,i) and broadcast in registers
                ymm0 = _mm256_unpacklo_pd(ymm0, ymm2);    //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
                ymm2 = _mm256_unpacklo_pd(ymm5, ymm6);    //A11[2][2] 1 A11[2][2] 1

                ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C);    //A11[0][0] A11[1][1] A11[2][2] 1
                ymm7 = _mm256_div_pd(ymm7, ymm0);            //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1)

                //NOTE(review): these broadcasts read column j+3 of A even though only 3 columns
                //remain, and the values are never used below — presumably safe padding reads,
                //but verify they cannot fall outside the allocated matrix.
                ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][3]
                ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][3]
                ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2));    //A11[2][3]

                //extract a00
                ymm0 = _mm256_permute_pd(ymm7, 0x00);              //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
                ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);   //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])

                ymm8 = _mm256_mul_pd(ymm8, ymm0);      //B11[0-3][0] /= A11[0][0]

                ymm12 = _mm256_mul_pd(ymm12, ymm0);    //B11[4-7][0] /= A11[0][0]

                //extract a11
                ymm0 = _mm256_permute_pd(ymm7, 0x03);              //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
                ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);   //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])

                //(Row1): FMA operations — eliminate column 0 from columns 1..2
                ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9);      //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
                ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10);    //B11[0-3][2] -= B11[0-3][0] * A11[0][2]

                ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13);   //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
                ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14);   //B11[4-7][2] -= B11[4-7][0] * A11[0][2]

                ymm9 = _mm256_mul_pd(ymm9, ymm0);      //B11[0-3][1] /= A11[1][1]

                ymm13 = _mm256_mul_pd(ymm13, ymm0);    //B11[4-7][1] /= A11[1][1]

                //extract a22
                ymm0 = _mm256_permute_pd(ymm7, 0x00);              //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
                ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);   //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])

                //(Row2)FMA operations — eliminate column 1 from column 2
                ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10);    //B11[0-3][2] -= B11[0-3][1] * A11[1][2]

                ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14);   //B11[4-7][2] -= B11[4-7][1] * A11[1][2]

                ymm10 = _mm256_mul_pd(ymm10, ymm0);    //B11[0-3][2] /= A11[2][2]

                ymm14 = _mm256_mul_pd(ymm14, ymm0);    //B11[4-7][2] /= A11[2][2]

                //write solved 8x3 block back to B11
                _mm256_storeu_pd((double *)b11, ymm8);                            //store(B11[0-3][0])
                _mm256_storeu_pd((double *)(b11 + D_NR), ymm12);                  //store(B11[4-7][0])
                _mm256_storeu_pd((double *)(b11 + cs_b), ymm9);                   //store(B11[0-3][1])
                _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13);           //store(B11[4-7][1])
                _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10);        //store(B11[0-3][2])
                _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
            }
|
|
            else if(2 == n_remainder)
            {
                ///GEMM implementation begins///

                for(k = 0; k < k_iter; k++)    ///loop for number of GEMM operations
                {
                    ptr_a01_dup = a01;    //remember start of this A01 block

                    //broadcast 1st row of A01 (only 2 columns needed)
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[0][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[0][1]

                    a01 += cs_a;    //move to next row of A

                    //load 8x2 block of B10
                    ymm12 = _mm256_loadu_pd((double const *)b10);                   //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
                    ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR));          //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
                    ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b));          //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
                    ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));   //B10[4][1] B10[5][1] B10[6][1] B10[7][1]

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);      //ymm0 += B10[0-3][0] * A01[0][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1);      //ymm1 += B10[0-3][0] * A01[0][1]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);      //ymm4 += B10[4-7][0] * A01[0][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5);      //ymm5 += B10[4-7][0] * A01[0][1]

                    //broadcast 2nd row of A01
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[1][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[1][1]

                    a01 += cs_a;    //move to next row of A

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);      //ymm0 += B10[0-3][1] * A01[1][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1);      //ymm1 += B10[0-3][1] * A01[1][1]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);      //ymm4 += B10[4-7][1] * A01[1][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5);      //ymm5 += B10[4-7][1] * A01[1][1]

                    //broadcast 3rd row of A01
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[2][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[2][1]

                    a01 += cs_a;    //move to next row of A

                    //load next 8x2 block of B10
                    ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));               //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
                    ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR));        //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
                    ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b));        //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
                    ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0);      //ymm0 += B10[0-3][2] * A01[2][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1);      //ymm1 += B10[0-3][2] * A01[2][1]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4);      //ymm4 += B10[4-7][2] * A01[2][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5);      //ymm5 += B10[4-7][2] * A01[2][1]

                    //broadcast 4th row of A01
                    ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[3][0]
                    ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[3][1]

                    a01 += cs_a;    //move to next row of A

                    ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0);      //ymm0 += B10[0-3][3] * A01[3][0]
                    ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1);      //ymm1 += B10[0-3][3] * A01[3][1]

                    ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4);      //ymm4 += B10[4-7][3] * A01[3][0]
                    ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5);      //ymm5 += B10[4-7][3] * A01[3][1]

                    b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
                    a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
                }

                ///GEMM code ends///
                ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    //alpha in all four lanes

                //load 8x2 block of B11
                ymm8 = _mm256_loadu_pd((double const *)b11);                     //B11[0-3][0]
                ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));           //B11[4-7][0]
                ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b));            //B11[0-3][1]
                ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR));    //B11[4-7][1]

                //B11 := alpha*B11 - (GEMM accumulators)
                ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0);      //B11[0-3][0] * alpha -= ymm0
                ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1);      //B11[0-3][1] * alpha -= ymm1
                ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4);    //B11[4-7][0] * alpha -= ymm4
                ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5);    //B11[4-7][1] * alpha -= ymm5

                ///implement TRSM///

                ///read 2x2 block of A11///
                //A11 element comments below use row-major indexing: A11[r][c] = a11 + c + r*cs_a

                ymm7 = _mm256_broadcast_sd((double const *)(&ones));    //1.0, numerator for the reciprocals

                //1st col
                ymm0 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][0]

                //2nd col
                a11 += 1;
                ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0));    //A11[0][1]
                ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1));    //A11[1][1]

                //compute reciprocals of L(i,i) and broadcast in registers
                ymm0 = _mm256_unpacklo_pd(ymm0, ymm2);    //A11[0][0] A11[1][1] A11[0][0] A11[1][1]

                ymm0 = _mm256_blend_pd(ymm0, ymm7, 0x0C);    //A11[0][0] A11[1][1] 1 1 (upper lanes padded with 1.0 — only 2 columns active)
                ymm7 = _mm256_div_pd(ymm7, ymm0);            //(1/A11[0][0] 1/A11[1][1] 1 1)

                //extract a00
                ymm0 = _mm256_permute_pd(ymm7, 0x00);              //(1/A11[0][0] 1/A11[0][0] 1 1)
                ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);   //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])

                ymm8 = _mm256_mul_pd(ymm8, ymm0);      //B11[0-3][0] /= A11[0][0]

                ymm12 = _mm256_mul_pd(ymm12, ymm0);    //B11[4-7][0] /= A11[0][0]

                //extract a11
                ymm0 = _mm256_permute_pd(ymm7, 0x03);              //(1/A11[1][1] 1/A11[1][1] 1 1)
                ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);   //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])

                //(Row1): FMA operations — eliminate column 0 from column 1
                ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9);      //B11[0-3][1] -= B11[0-3][0] * A11[0][1]

                ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13);   //B11[4-7][1] -= B11[4-7][0] * A11[0][1]

                ymm9 = _mm256_mul_pd(ymm9, ymm0);      //B11[0-3][1] /= A11[1][1]

                ymm13 = _mm256_mul_pd(ymm13, ymm0);    //B11[4-7][1] /= A11[1][1]

                //write solved 8x2 block back to B11
                _mm256_storeu_pd((double *)b11, ymm8);                   //store(B11[0-3][0])
                _mm256_storeu_pd((double *)(b11 + D_NR), ymm12);         //store(B11[4-7][0])
                _mm256_storeu_pd((double *)(b11 + cs_b), ymm9);          //store(B11[0-3][1])
                _mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13);  //store(B11[4-7][1])
            }
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
|
|
///implement TRSM///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm7); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm7); //B11[4-7][0] /= A11[0][0]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM for previous blocks ///
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM for previous blocks ///
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm14, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM for previous blocks ///
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
ymm14 = _mm256_div_pd(ymm14, ymm4); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm14); //B11[x][0] /= A11[0][0]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b_offset[1])[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b_offset[1]);
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//1st row
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0));
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//4th row
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm6); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm6 = _mm256_unpacklo_pd(ymm9, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm6, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15); //B11[x][0] /= A11[0][0]
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15); //B11[x][1] /= A11[1][1]
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15); //B11[x][2] /= A11[2][2]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15); //B11[x][3] /= A11[3][3]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b_offset[1])[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * (n_remainder-1));
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
///GEMM for previous blocks ///
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAltB(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower-triangular, unit-diagonal, and transposed
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* b11---> a01 ---->
|
|
***************** ***********
|
|
*b01*b11* * * * * * *
|
|
b11 * * * * * **a01 * * a11
|
|
| ***************** ********* |
|
|
| * * * * * *a11* * |
|
|
| * * * * * * * * |
|
|
v ***************** ****** v
|
|
* * * * * * *
|
|
* * * * * * *
|
|
***************** * *
|
|
*
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAltB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_M && n>D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_ROW_PANEL_N)
|
|
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_M && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_SQUARE_N)
|
|
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|
|
|| (m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME)
|
|
|| (m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALTB_ROME_COLUMN_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double *L = a->buffer; //pointer to matrix A
|
|
double *B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
double f_t[4] __attribute__((aligned(64)));//buffer to store corner column when m_remainder !=0
|
|
double* f_temp;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = 0; (i+D_MR-1) < m; i += D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = 0; (j+D_NR-1) < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used in GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm2, ymm8, ymm11); //B11[0-3][3] -= B11[0-3][0] * A11[0][3]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm2, ymm12, ymm15); //B11[4-7][3] -= B11[4-7][0] * A11[0][3]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
ymm11 = _mm256_fnmadd_pd(ymm5, ymm9, ymm11); //B11[0-3][3] -= B11[0-3][1] * A11[1][3]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
ymm15 = _mm256_fnmadd_pd(ymm5, ymm13, ymm15); //B11[4-7][3] -= B11[4-7][1] * A11[1][3]
|
|
|
|
//(Row3)FMA operations
|
|
ymm11 = _mm256_fnmadd_pd(ymm6, ymm10, ymm11); //B11[0-3][3] -= B11[0-3][2] * A11[2][3]
|
|
|
|
ymm15 = _mm256_fnmadd_pd(ymm6, ymm14, ymm15); //B11[4-7][3] -= B11[4-7][2] * A11[2][3]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0-3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm15, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm15, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); //B11[0-3][2] -= B11[0-3][0] * A11[0][2]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); //B11[4-7][2] -= B11[4-7][0] * A11[0][2]
|
|
|
|
//(Row2)FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm4, ymm9, ymm10); //B11[0-3][2] -= B11[0-3][1] * A11[1][2]
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm4, ymm13, ymm14); //B11[4-7][2] -= B11[4-7][1] * A11[1][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)b11);
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR));
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR));
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm15, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm15, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
|
|
//(Row1): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm1, ymm8, ymm9); //B11[0-3][1] -= B11[0-3][0] * A11[0][1]
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm1, ymm12, ymm13); //B11[4-7][1] -= B11[4-7][0] * A11[0][1]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0-3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4-7][0]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm15, ymm0); //B11[0-3][0] * alpha -= B10[0-3][0]
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm15, ymm4); //B11[0-3][2] * alpha -= B10[0-3][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = 0; (j+D_NR-1)<n; j +=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
}
|
|
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM for previous blocks ///
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
                ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM for previous blocks ///
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
            ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st row
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM for previous blocks ///
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
            ///GEMM processing starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -= ymm4
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
}
|
|
m_remainder -= 4;
|
|
i += 4;
|
|
}
|
|
    if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
for(j = 0; j+D_NR-1 < n; j += D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of time GEMM to be performed(in blocks of 4x4)
|
|
|
|
dim_t iter;
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
f_temp = f_t;
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b_offset[1])[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b_offset[1]);
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
            ///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
                //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM implementation ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//1st row
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1));
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//2nd row
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+2));
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//3rd row
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3));
|
|
|
|
a11 += cs_a;//move to next column
|
|
|
|
//(Row1): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm5, ymm0, ymm1); //B11[x][1] -= A11[0][1] * B11[x][0]
|
|
ymm2 = _mm256_fnmadd_pd(ymm7, ymm0, ymm2); //B11[x][2] -= A11[0][2] * B11[x][0]
|
|
ymm3 = _mm256_fnmadd_pd(ymm10, ymm0, ymm3); //B11[x][3] -= A11[0][3] * B11[x][0]
|
|
|
|
//(Row2)FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm8, ymm1, ymm2); //B11[x][2] -= A11[1][2] * B11[x][1]
|
|
ymm3 = _mm256_fnmadd_pd(ymm11, ymm1, ymm3); //B11[x][3] -= A11[1][3] * B11[x][1]
|
|
|
|
//(Row3)FMA operations
|
|
ymm3 = _mm256_fnmadd_pd(ymm12, ymm2, ymm3); //B11[x][3] -= A11[2][3] * B11[x][2]
|
|
|
|
ymm4 = _mm256_loadu_pd((double const *)(b11));
|
|
ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b));
|
|
ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));
|
|
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm4, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm1, ymm5, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm2, ymm6, 0x08);
|
|
ymm3 = _mm256_blend_pd(ymm3, ymm7, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm0,ymm4,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm1,ymm5,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm2,ymm6,0x30);
|
|
ymm3 = _mm256_permute2f128_pd(ymm3,ymm7,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm0,ymm4,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm1,ymm5,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm2,ymm6,0x0E);
|
|
ymm3 = _mm256_blend_pd(ymm3,ymm7,0x0E);
|
|
}
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(f_temp), ymm3); //store(B11[x][3])
|
|
|
|
if((j+D_NR) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b_offset[1])[iter] = f_temp[iter];
|
|
}
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when 'N' is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j; //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = j / D_NR; //number of GEMM operations to be performed(in block of 4x4)
|
|
|
|
dim_t iter;
|
|
err_t r_val;
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
f_temp = bli_malloc_user(4 * sizeof(double), &r_val);
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
f_temp[iter] = (b11 + cs_b * (n_remainder-1))[iter];
|
|
}
|
|
else
|
|
f_temp = (b11 + cs_b * (n_remainder-1));
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
///GEMM for previous blocks ///
|
|
|
|
if(3 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
ymm6 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
ymm2 = _mm256_blend_pd(ymm6, ymm2, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
ymm2 = _mm256_permute2f128_pd(ymm6,ymm2,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
ymm2 = _mm256_blend_pd(ymm6,ymm2,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm2); //(store(B11[x][2]))
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
ymm5 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
ymm1 = _mm256_blend_pd(ymm5, ymm1, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
ymm1 = _mm256_permute2f128_pd(ymm5,ymm1,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
ymm1 = _mm256_blend_pd(ymm5,ymm1,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(f_temp), ymm1); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)(f_temp)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
                    //load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
|
|
                    //broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
///GEMM implementation ends
|
|
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); ///register to store alpha
|
|
|
|
ymm4 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);
|
|
///implement TRSM///
|
|
if(3 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4, ymm0, 0x08);
|
|
}
|
|
else if(2 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_permute2f128_pd(ymm4,ymm0,0x30);
|
|
}
|
|
else if(1 == m_remainder)
|
|
{
|
|
ymm0 = _mm256_blend_pd(ymm4,ymm0,0x0E);
|
|
}
|
|
_mm256_storeu_pd((double *)f_temp, ymm0); //store(B11[x][0])
|
|
}
|
|
|
|
if((j+n_remainder) == n)
|
|
{
|
|
for(iter = 0; iter < m_remainder; iter++)
|
|
(b11 + cs_b * (n_remainder-1))[iter] = f_temp[iter];
|
|
}
|
|
//scalar code for TRSM
|
|
dtrsm_small_XAltB_unitDiag(a11, b11, m_remainder, n_remainder, cs_a, cs_b);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower triangular, non-unit diagonal, no transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAlB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME)
|
|
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
    dim_t i, j, k; //loop variables
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double* restrict L = a->buffer; //pointer to matrix A
|
|
double* restrict B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
    //ymm scratch registers
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
|
|
//extract a00
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1): FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm0); //B11[0-3][0] /= A11[0][0]
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm0); //B11[4-7][0] /= A11[0][0]
|
|
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
//extract a11
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x00);//(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm0);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm0);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//3rd col
|
|
a11 += 2;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[1][1] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm7 = _mm256_div_pd(ymm7, ymm0); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//extract a33
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
//extract a22
|
|
ymm0 = _mm256_permute_pd(ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm0 = _mm256_permute2f128_pd(ymm0, ymm0, 0x11);//(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm0);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm0);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)(&ones));
|
|
|
|
//4th col
|
|
a11 += 3;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 3)); //A11[3][3]
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
ymm7 = _mm256_div_pd(ymm7, ymm6); //(1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm7);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm7);
|
|
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM for previous blocks ///
|
|
if(3 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_broadcast_sd((double const *)&ones); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm14, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//3rd col
|
|
a11 += 2 * cs_a;
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][2]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//4th col
|
|
a11 += 3 * cs_a;
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][3]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm14);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
// if(i < 0) i = 0;
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
dtrsm_small_XAlB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/* Implements TRSM for the case X * A = alpha * B,
 * where A is lower triangular, unit-diagonal, and not transposed.
 * Dimensions: X is m x n, A is n x n, B is m x n.
 */
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAlB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME)
|
|
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double* restrict L = a->buffer; //pointer to matrix A
|
|
double* restrict B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch reginsters
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + j*cs_a +(j+D_NR); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//2nd col
|
|
a11 += 1;
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][1]
|
|
|
|
//3rd col
|
|
a11 += 1;
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][2]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 0)); //A11[0][3]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
//(Row 1): FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[4-7][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b), ymm11); //store(B11[0-3][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + cs_b + D_NR), ymm15);//store(B11[4-7][3])
|
|
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j + D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///load 4x4 block of b11
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//subtract the calculated GEMM block from current TRSM block
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] )); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1] + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha -= B10[4-7][0]
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= B10[4-7][2]
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//3rd col
|
|
a11 += 2;
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][2]
|
|
|
|
//4th col
|
|
a11 += 1;
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 1)); //A11[1][3]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
|
|
//(Row 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0-3][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha -= B10[0-3][1]
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= B10[0-3][3]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//4th col
|
|
a11 += 3;
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+ cs_a * 2)); //A11[2][3]
|
|
|
|
//(row 3):FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM implementation begins///
|
|
|
|
for(k = 0; k < k_iter; k++) ///loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));//B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= B10[4-7][1]
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= B10[4-7][3]
|
|
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][0]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][0]
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][0]
|
|
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
a01 = L + j*cs_a + (j+D_NR); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointwr to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM for previous blocks ///
|
|
if(3 == n_remainder)
|
|
{
|
|
///load 4x4 block of b11
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
//2nd col
|
|
a11 += cs_a;
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][1]
|
|
|
|
//3rd col
|
|
a11 += cs_a;
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//3rd col
|
|
a11 += 2 * cs_a;
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[1][2]
|
|
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM processing stars///
|
|
|
|
for(k = 0; k < k_iter; k++)
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[0][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[1][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[2][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + cs_a * 3)); //A01[3][3]
|
|
|
|
a01 += 1; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR; //pointer math to find next block of A for GEMM
|
|
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha value
|
|
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
// if(i < 0) i = 0;
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
dtrsm_small_XAlB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is upper triangular, non-unit diagonal, transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAutB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME)
|
|
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double ones = 1.0;
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double* restrict L = a->buffer; //pointer to matrix A
|
|
double* restrict B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch reginsters
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
//extract a33
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm7);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm7);
|
|
|
|
//extract a22
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm7);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm7);
|
|
|
|
//extract a11
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm7);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm7);
|
|
|
|
//extract A00
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
ymm8 = _mm256_mul_pd(ymm8, ymm7);
|
|
|
|
ymm12 = _mm256_mul_pd(ymm12, ymm7);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm0 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm0 = _mm256_unpacklo_pd(ymm0, ymm2); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm0, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
//extract a33
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm7);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm7);
|
|
|
|
//extract a22
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm7);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm7);
|
|
|
|
//extract a11
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
|
|
ymm9 = _mm256_mul_pd(ymm9, ymm7);
|
|
|
|
ymm13 = _mm256_mul_pd(ymm13, ymm7);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4-7][0]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][1]
|
|
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
a11 += 2 * cs_a;
|
|
|
|
//3rd col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm2 = _mm256_unpacklo_pd(ymm5, ymm6); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm0 = _mm256_blend_pd(ymm7, ymm2, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm0 = _mm256_div_pd(ymm7, ymm0); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
//extract a33
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm7);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm7);
|
|
|
|
//extract a22
|
|
ymm7 = _mm256_permute_pd(ymm0, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm7 = _mm256_permute2f128_pd(ymm7, ymm7, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
|
|
ymm10 = _mm256_mul_pd(ymm10, ymm7);
|
|
|
|
ymm14 = _mm256_mul_pd(ymm14, ymm7);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1])); //B11[0-3][0]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR)); //B11[4-7][0]
|
|
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
a11 += 3 * cs_a;
|
|
|
|
//4th col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm7 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
ymm0 = _mm256_div_pd(ymm7, ymm6); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm11 = _mm256_mul_pd(ymm11, ymm0);
|
|
|
|
ymm15 = _mm256_mul_pd(ymm15, ymm0);
|
|
|
|
_mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
}
|
|
}
|
|
if(i<0)
|
|
i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -=D_NR) //loop along n direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
///load 4x4 block of b11
|
|
ymm0 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 block of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4); //ymm4 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4); //ymm4 += (B10[0][1]*A01[1][0] B10[1][1]*A01[1][0] B10[2][1]*A01[1][0] B10[3][1]*A01[1][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4); //ymm4 += (B10[0][2]*A01[2][0] B10[1][2]*A01[2][0] B10[2][2]*A01[2][0] B10[3][2]*A01[2][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4); //ymm4 += (B10[0][3]*A01[3][0] B10[1][3]*A01[3][0] B10[2][3]*A01[3][0] B10[3][3]*A01[3][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + D_NR*cs_a; //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4); //B11[x][0] -=ymm4
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm7 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
//extract A00
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[0][0] 1/A11[0][0])
|
|
|
|
//(Row 1):FMA operations
|
|
ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);
|
|
|
|
ymm0 = _mm256_mul_pd(ymm0, ymm15);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm0); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
///GEMM for previous blocks ///
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///load 4x4 block of b11
|
|
if(3 == n_remainder)
|
|
{
|
|
ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5); //ymm5 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5); //ymm5 += (B10[0][1]*A01[1][1] B10[1][1]*A01[1][1] B10[2][1]*A01[1][1] B10[3][1]*A01[1][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5); //ymm5 += (B10[0][2]*A01[2][1] B10[1][2]*A01[2][1] B10[2][2]*A01[2][1] B10[3][2]*A01[2][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5); //ymm5 += (B10[0][3]*A01[3][1] B10[1][3]*A01[3][1] B10[2][3]*A01[3][1] B10[3][3]*A01[3][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5); //B11[x][1] -= ymm5
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(&ones)); //A11[0][0]
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm4 = _mm256_unpacklo_pd(ymm4, ymm8); //A11[0][0] A11[1][1] A11[0][0] A11[1][1]
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm4, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
//extract a11
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x03); //(1/A11[1][1] 1/A11[1][1] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x00); //(1/A11[1][1] 1/A11[1][1] 1/A11[1][1] 1/A11[1][1])
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);
|
|
|
|
ymm1 = _mm256_mul_pd(ymm1, ymm15);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm1); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3); //store(B11[x][0])
|
|
}
|
|
else if(2 == n_remainder)
|
|
{
|
|
ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6); //ymm6 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6); //ymm6 += (B10[0][1]*A01[1][2] B10[1][1]*A01[1][2] B10[2][1]*A01[1][2] B10[3][1]*A01[1][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6); //ymm6 += (B10[0][2]*A01[2][2] B10[1][2]*A01[2][2] B10[2][2]*A01[2][2] B10[3][2]*A01[2][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6); //ymm6 += (B10[0][3]*A01[3][2] B10[1][3]*A01[3][2] B10[2][3]*A01[3][2] B10[3][3]*A01[3][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6); //B11[x][2] -= ymm6
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
|
|
a11 += 2 * cs_a;
|
|
|
|
//3rd col
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm12 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm8 = _mm256_unpacklo_pd(ymm11, ymm13); //A11[2][2] A11[3][3] A11[2][2] A11[3][3]
|
|
|
|
ymm15 = _mm256_blend_pd(ymm14, ymm8, 0x0C); //A11[0][0] A11[1][1] A11[2][2] A11[3][3]
|
|
ymm14 = _mm256_div_pd(ymm14, ymm15); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
//extract a33
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x0C); //(1/A11[0][0] 1/A11[0][0] 1/A11[3][3] 1/A11[3][3])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[3][3] 1/A11[3][3] 1/A11[3][3] 1/A11[3][3])
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm15);
|
|
|
|
//extract a22
|
|
ymm15 = _mm256_permute_pd(ymm14, 0x00); //(1/A11[0][0] 1/A11[0][0] 1/A11[2][2] 1/A11[2][2])
|
|
ymm15 = _mm256_permute2f128_pd(ymm15, ymm15, 0x11); //(1/A11[2][2] 1/A11[2][2] 1/A11[2][2] 1/A11[2][2])
|
|
|
|
//(Row 3): FMA operations
|
|
ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);
|
|
|
|
ymm2 = _mm256_mul_pd(ymm2, ymm15);
|
|
|
|
|
|
_mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][1])
|
|
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
a11 += 3 * cs_a;
|
|
|
|
//4th col
|
|
ymm13 = _mm256_broadcast_sd((double const *)(a11+3)); //A11[0][1]
|
|
|
|
ymm14 = _mm256_broadcast_sd((double const *)&ones);
|
|
|
|
//compute reciprocals of A(i,i) and broadcast in registers
|
|
ymm14 = _mm256_div_pd(ymm14, ymm13); // 1/A11[0][0] 1/A11[1][1] 1/A11[2][2] 1/A11[3][3]
|
|
|
|
ymm3 = _mm256_mul_pd(ymm3, ymm14);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
dtrsm_small_XAutB(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*implements TRSM for the case XA = alpha * B
|
|
*A is lower triangular, unit-diagonal, no transpose
|
|
*dimensions: X:mxn A:nxn B: mxn
|
|
*/
|
|
|
|
/* <---b11 <---a11
|
|
***************** *
|
|
*b01*b11* * * * *
|
|
^ * * * * * ^ * *
|
|
| ***************** | *******
|
|
| * * * * * | * * *
|
|
| * * * * * a01* * *
|
|
b10 ***************** *************
|
|
* * * * * * * * *
|
|
* * * * * * * * *
|
|
***************** *******************
|
|
|
|
*/
|
|
static err_t bli_dtrsm_small_XAutB_unitDiag(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
dim_t D_MR = 8; //block dimension along the rows
|
|
dim_t D_NR = 4; //block dimension along the columns
|
|
|
|
dim_t m = bli_obj_length(b); //number of rows
|
|
dim_t n = bli_obj_width(b); //number of columns
|
|
|
|
dim_t m_remainder = m & 7; //number of corner rows
|
|
dim_t n_remainder = n & 3; //number of corner columns
|
|
|
|
dim_t cs_a = bli_obj_col_stride(a); //column stride of matrix A
|
|
dim_t cs_b = bli_obj_col_stride(b); //column stride of matrix B
|
|
|
|
#ifdef BLIS_ENABLE_SMALL_MATRIX_ROME
|
|
if((m < D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME)
|
|
||(m > D_BLIS_SMALL_MATRIX_THRES_TRSM_XAUTB_ROME && n > D_BLIS_SMALL_MATRIX_THRES_TRSM_XALB_ROME_COL_PANEL_N)
|
|
)
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
#else
|
|
if(bli_max(m,n) > D_BLIS_SMALL_MATRIX_THRES_TRSM_NAPLES)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
#endif
|
|
|
|
dim_t i, j, k; //loop variablse
|
|
dim_t k_iter; //determines the number of GEMM operations to be done
|
|
dim_t cs_b_offset[2]; //pre-calculated strides
|
|
|
|
double AlphaVal = *(double *)AlphaObj->buffer; //value of Alpha
|
|
double* restrict L = a->buffer; //pointer to matrix A
|
|
double* restrict B = b->buffer; //pointer to matrix B
|
|
|
|
double *a01, *a11, *b10, *b11; //pointers for GEMM and TRSM blocks
|
|
double *ptr_a01_dup;
|
|
|
|
cs_b_offset[0] = cs_b << 1; //cs_b_offset[0] = cs_b * 2;
|
|
cs_b_offset[1] = cs_b_offset[0] + cs_b;//cs_b_offset[1] = cs_b * 3;
|
|
|
|
//ymm scratch reginsters
|
|
__m256d ymm0, ymm1, ymm2, ymm3;
|
|
__m256d ymm4, ymm5, ymm6, ymm7;
|
|
__m256d ymm8, ymm9, ymm10, ymm11;
|
|
__m256d ymm12, ymm13, ymm14, ymm15;
|
|
__m256d ymm16;
|
|
|
|
for(i = (m-D_MR); (i+1) > 0; i -= D_MR) //loop along 'M' direction
|
|
{
|
|
for(j = (n-D_NR); (j+1) > 0; j -= D_NR) //loop along 'N' direction
|
|
{
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[0][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][0]*A01[0][0] B10[1][0]*A01[0][0] B10[2][0]*A01[0][0] B10[3][0]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][0]*A01[0][0] B10[5][0]*A01[0][0] B10[6][0]*A01[0][0] B10[7][0]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[1][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][1]*A01[0][0] B10[1][1]*A01[0][0] B10[2][1]*A01[0][0] B10[3][1]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][1]*A01[0][0] B10[5][1]*A01[0][0] B10[6][1]*A01[0][0] B10[7][1]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[2][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm12, ymm0); //ymm0 += (B10[0][2]*A01[0][0] B10[1][2]*A01[0][0] B10[2][2]*A01[0][0] B10[3][2]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm13, ymm4); //ymm4 += (B10[4][2]*A01[0][0] B10[5][2]*A01[0][0] B10[6][2]*A01[0][0] B10[7][2]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm8 = _mm256_broadcast_sd((double const *)(a01 + 0)); //A01[3][0]
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm0 = _mm256_fmadd_pd(ymm8, ymm14, ymm0); //ymm0 += (B10[0][3]*A01[0][0] B10[1][3]*A01[0][0] B10[2][3]*A01[0][0] B10[3][3]*A01[0][0])
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm4 = _mm256_fmadd_pd(ymm8, ymm15, ymm4); //ymm4 += (B10[4][3]*A01[0][0] B10[5][3]*A01[0][0] B10[6][3]*A01[0][0] B10[7][3]*A01[0][0])
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
|
|
ymm16 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
//load 8x4 block of B11
|
|
ymm8 = _mm256_loadu_pd((double const *)b11); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
ymm12 = _mm256_loadu_pd((double const *)(b11 + D_NR)); //B11[4][0] B11[5][0] B11[6][0] B11[7][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4][1] B11[5][1] B11[6][1] B11[7][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0])); //B11[0][2] B11[1][2] B11[2][2] B11[3][2]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR)); //B11[4][2] B11[5][2] B11[6][2] B11[7][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0][3] B11[1][3] B11[2][3] B11[3][3]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4][3] B11[5][3] B11[6][3] B11[7][3]
|
|
|
|
ymm8 = _mm256_fmsub_pd(ymm8, ymm16, ymm0); //B11[0-3][0] * alpha -= ymm0
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm16, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm16, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm16, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm12 = _mm256_fmsub_pd(ymm12, ymm16, ymm4); //B11[0-3][2] * alpha -= ymm4
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm16, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm16, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm16, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
//1st col
|
|
|
|
a11 += cs_a;
|
|
|
|
//2nd col
|
|
ymm1 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//3rd col
|
|
ymm3 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm2 = _mm256_broadcast_sd((double const *)(a11+0)); //A11[0][1]
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm11, ymm2, ymm8);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm15, ymm2, ymm12);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
ymm8 = _mm256_fnmadd_pd(ymm10, ymm3, ymm8);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
ymm12 = _mm256_fnmadd_pd(ymm14, ymm3, ymm12);
|
|
|
|
//(Row 1):FMA operations
|
|
ymm8 = _mm256_fnmadd_pd(ymm9, ymm1, ymm8);
|
|
|
|
ymm12 = _mm256_fnmadd_pd(ymm13, ymm1, ymm12);
|
|
|
|
_mm256_storeu_pd((double *)b11, ymm8); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + D_NR), ymm12); //store(B11[x][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[x][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14); //(store(B11[x][2]))
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[x][3])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[x][3])
|
|
|
|
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a +(j); //pointer to block of A to be used in GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used in GEMM
|
|
b11 = B + (i) + (j)*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of GEMM operations to be done(in blocks of 4x4)
|
|
|
|
ymm0 = _mm256_setzero_pd();
|
|
ymm1 = _mm256_setzero_pd();
|
|
ymm2 = _mm256_setzero_pd();
|
|
ymm3 = _mm256_setzero_pd();
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
//load 8x4 block of B11
|
|
if(3 == n_remainder)
|
|
{
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//broadcast 1st row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[0][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[0][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row
|
|
|
|
//load 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR)); //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR)); //B10[4][1] B10[5][1] B10[6][1] B10[7][1]
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][0]*A01[0][1] B10[1][0]*A01[0][1] B10[2][0]*A01[0][1] B10[3][0]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][0]*A01[0][2] B10[1][0]*A01[0][2] B10[2][0]*A01[0][2] B10[3][0]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][0]*A01[0][1] B10[5][0]*A01[0][1] B10[6][0]*A01[0][1] B10[7][0]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][0]*A01[0][2] B10[5][0]*A01[0][2] B10[6][0]*A01[0][2] B10[7][0]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][0]*A01[0][3] B10[5][0]*A01[0][3] B10[6][0]*A01[0][3] B10[7][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[1][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[1][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][1]*A01[0][1] B10[1][1]*A01[0][1] B10[2][1]*A01[0][1] B10[3][1]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][1]*A01[0][2] B10[1][1]*A01[0][2] B10[2][1]*A01[0][2] B10[3][1]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][1]*A01[0][3] B10[1][1]*A01[0][3] B10[2][1]*A01[0][3] B10[3][1]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][1]*A01[0][1] B10[5][1]*A01[0][1] B10[6][1]*A01[0][1] B10[7][1]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][1]*A01[0][2] B10[5][1]*A01[0][2] B10[6][1]*A01[0][2] B10[7][1]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][1]*A01[0][3] B10[5][1]*A01[0][3] B10[6][1]*A01[0][3] B10[7][1]*A01[0][3])
|
|
|
|
//broadcast 3rd row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[2][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[2][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
//load next 8x2 block of B10
|
|
ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //(B10[0][2] B10[1][2] B10[2][2] B10[3][2])
|
|
ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR)); //(B10[4][2] B10[5][2] B10[6][2] B10[7][2])
|
|
ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b)); //(B10[0][3] B10[1][3] B10[2][3] B10[3][3])
|
|
ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR)); //(B10[4][3] B10[5][3] B10[6][3] B10[7][3])
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm12, ymm1); //ymm1 += (B10[0][2]*A01[0][1] B10[1][2]*A01[0][1] B10[2][2]*A01[0][1] B10[3][2]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2); //ymm2 += (B10[0][2]*A01[0][2] B10[1][2]*A01[0][2] B10[2][2]*A01[0][2] B10[3][2]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3); //ymm3 += (B10[0][2]*A01[0][3] B10[1][2]*A01[0][3] B10[2][2]*A01[0][3] B10[3][2]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm13, ymm5); //ymm5 += (B10[4][2]*A01[0][1] B10[5][2]*A01[0][1] B10[6][2]*A01[0][1] B10[7][2]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6); //ymm6 += (B10[4][2]*A01[0][2] B10[5][2]*A01[0][2] B10[6][2]*A01[0][2] B10[7][2]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7); //ymm7 += (B10[4][2]*A01[0][3] B10[5][2]*A01[0][3] B10[6][2]*A01[0][3] B10[7][2]*A01[0][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm9 = _mm256_broadcast_sd((double const *)(a01 + 1)); //A01[3][1]
|
|
ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2)); //A01[3][2]
|
|
ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A01
|
|
|
|
ymm1 = _mm256_fmadd_pd(ymm9, ymm14, ymm1); //ymm1 += (B10[0][3]*A01[0][1] B10[1][3]*A01[0][1] B10[2][3]*A01[0][1] B10[3][3]*A01[0][1])
|
|
ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2); //ymm2 += (B10[0][3]*A01[0][2] B10[1][3]*A01[0][2] B10[2][3]*A01[0][2] B10[3][3]*A01[0][2])
|
|
ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3); //ymm3 += (B10[0][3]*A01[0][3] B10[1][3]*A01[0][3] B10[2][3]*A01[0][3] B10[3][3]*A01[0][3])
|
|
|
|
ymm5 = _mm256_fmadd_pd(ymm9, ymm15, ymm5); //ymm5 += (B10[4][3]*A01[0][1] B10[5][3]*A01[0][1] B10[6][3]*A01[0][1] B10[7][3]*A01[0][1])
|
|
ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6); //ymm6 += (B10[4][3]*A01[0][2] B10[5][3]*A01[0][2] B10[6][3]*A01[0][2] B10[7][3]*A01[0][2])
|
|
ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7); //ymm7 += (B10[4][3]*A01[0][3] B10[5][3]*A01[0][3] B10[6][3]*A01[0][3] B10[7][3]*A01[0][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code ends///
|
|
ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);
|
|
|
|
ymm9 = _mm256_loadu_pd((double const *)(b11+cs_b)); //B11[0-3][0]
|
|
ymm13 = _mm256_loadu_pd((double const *)(b11 + cs_b + D_NR)); //B11[4-7][0]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b*2)); //B11[0-3][1]
|
|
ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b*2 + D_NR)); //B11[4-7][1]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1])); //B11[0-3][2]
|
|
ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR)); //B11[4-7][2]
|
|
|
|
ymm9 = _mm256_fmsub_pd(ymm9, ymm8, ymm1); //B11[4-7][0] * alpha-= ymm1
|
|
ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2); //B11[0-3][1] * alpha-= ymm2
|
|
ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3); //B11[4-7][1] * alpha -= ymm3
|
|
|
|
ymm13 = _mm256_fmsub_pd(ymm13, ymm8, ymm5); //B11[4-7][2] * alpha -= ymm5
|
|
ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6); //B11[0-3][3] * alpha -= ymm6
|
|
ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7); //B11[4-7][3] * alpha -= ymm7
|
|
|
|
///implement TRSM///
|
|
|
|
///read 4x4 block of A11///
|
|
|
|
a11 += 2 * cs_a;
|
|
|
|
//3rd col
|
|
ymm4 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
|
|
a11 += cs_a;
|
|
|
|
//4th col
|
|
ymm5 = _mm256_broadcast_sd((double const *)(a11+1)); //A11[0][1]
|
|
ymm6 = _mm256_broadcast_sd((double const *)(a11+2)); //A11[0][1]
|
|
|
|
//(Row 3): FMA operations
|
|
ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);
|
|
ymm9 = _mm256_fnmadd_pd(ymm11, ymm5, ymm9);
|
|
|
|
//(Row 3): FMA operations
|
|
ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);
|
|
ymm13 = _mm256_fnmadd_pd(ymm15, ymm5, ymm13);
|
|
|
|
//(ROW 2): FMA operations
|
|
ymm9 = _mm256_fnmadd_pd(ymm10, ymm4, ymm9);
|
|
|
|
ymm13 = _mm256_fnmadd_pd(ymm14, ymm4, ymm13);
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b), ymm9); //store(B11[0-3][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b + D_NR), ymm13); //store(B11[4-7][1])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10); //store(B11[0-3][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);//store(B11[4-7][2])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11); //store(B11[0-3][0])
|
|
_mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15); //store(B11[4-7][0])
|
|
}
|
|
else if(2 == n_remainder)
{
    ///GEMM implementation starts///
    // Rank-update for the last two columns (cols 2 and 3) of the 8-row tile:
    // ymm2/ymm3 accumulate rows 0-3, ymm6/ymm7 accumulate rows 4-7.
    for(k = 0; k < k_iter; k++)    //loop for number of GEMM operations
    {
        ptr_a01_dup = a01;

        //broadcast 1st row of A01 (only columns 2 and 3 are needed)
        ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[0][2]
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[0][3]

        a01 += cs_a;    //move to next row

        //load 8x2 block of B10 (columns 0 and 1)
        ymm12 = _mm256_loadu_pd((double const *)b10);                    //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
        ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR));           //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
        ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b));           //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
        ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));    //B10[4][1] B10[5][1] B10[6][1] B10[7][1]

        ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2);    //ymm2 += B10[0-3][0] * A01[0][2]
        ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3);    //ymm3 += B10[0-3][0] * A01[0][3]

        ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6);    //ymm6 += B10[4-7][0] * A01[0][2]
        ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7);    //ymm7 += B10[4-7][0] * A01[0][3]

        //broadcast 2nd row of A01
        ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[1][2]
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[1][3]

        a01 += cs_a;    //move to next row of A

        ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2);    //ymm2 += B10[0-3][1] * A01[1][2]
        ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3);    //ymm3 += B10[0-3][1] * A01[1][3]

        ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6);    //ymm6 += B10[4-7][1] * A01[1][2]
        ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7);    //ymm7 += B10[4-7][1] * A01[1][3]

        //broadcast 3rd row of A01
        ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[2][2]
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[2][3]

        a01 += cs_a;    //move to next row of A01

        //load next 8x2 block of B10 (columns 2 and 3)
        ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));                  //B10[0-3][2]
        ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR));           //B10[4-7][2]
        ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b));           //B10[0-3][3]
        ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR));    //B10[4-7][3]

        ymm2 = _mm256_fmadd_pd(ymm10, ymm12, ymm2);    //ymm2 += B10[0-3][2] * A01[2][2]
        ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3);    //ymm3 += B10[0-3][2] * A01[2][3]

        ymm6 = _mm256_fmadd_pd(ymm10, ymm13, ymm6);    //ymm6 += B10[4-7][2] * A01[2][2]
        ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7);    //ymm7 += B10[4-7][2] * A01[2][3]

        //broadcast 4th row of A01
        ymm10 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[3][2]
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[3][3]

        a01 += cs_a;    //move to next row of A01

        ymm2 = _mm256_fmadd_pd(ymm10, ymm14, ymm2);    //ymm2 += B10[0-3][3] * A01[3][2]
        ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3);    //ymm3 += B10[0-3][3] * A01[3][3]

        ymm6 = _mm256_fmadd_pd(ymm10, ymm15, ymm6);    //ymm6 += B10[4-7][3] * A01[3][2]
        ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7);    //ymm7 += B10[4-7][3] * A01[3][3]

        b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
        a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
    }

    ///GEMM code ends///
    ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);    //register to hold alpha

    //load the two B11 columns being solved (cols 2 and 3)
    ymm10 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));           //B11[0-3][2]
    ymm14 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0] + D_NR));    //B11[4-7][2]
    ymm11 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));           //B11[0-3][3]
    ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] + D_NR));    //B11[4-7][3]

    //B11 := alpha*B11 - (accumulated GEMM result)
    ymm10 = _mm256_fmsub_pd(ymm10, ymm8, ymm2);    //B11[0-3][2] * alpha -= ymm2
    ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3);    //B11[0-3][3] * alpha -= ymm3

    ymm14 = _mm256_fmsub_pd(ymm14, ymm8, ymm6);    //B11[4-7][2] * alpha -= ymm6
    ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7);    //B11[4-7][3] * alpha -= ymm7

    ///implement TRSM///

    ///read the single off-diagonal entry of A11 that couples cols 2 and 3///
    a11 += 3 * cs_a;

    //4th col: a11 + 3*cs_a + 2 addresses A11[2][3] (column-major, cs_a column stride)
    ymm6 = _mm256_broadcast_sd((double const *)(a11+2));    //A11[2][3]

    //(Row 3): eliminate the col-3 contribution from col 2, rows 0-3
    //NOTE(review): no divide by the diagonal anywhere in this path --
    //presumably A11 is treated as unit-diagonal here; confirm against the kernel entry.
    ymm10 = _mm256_fnmadd_pd(ymm11, ymm6, ymm10);

    //(Row 3): same elimination for rows 4-7
    ymm14 = _mm256_fnmadd_pd(ymm15, ymm6, ymm14);

    _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm10);           //store B11[0-3][2]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[0] + D_NR), ymm14);    //store B11[4-7][2]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm11);           //store B11[0-3][3]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15);    //store B11[4-7][3]
}
|
|
else if(1 == n_remainder)
{
    ///GEMM implementation starts///
    // Rank-update for the single remaining column (col 3) of the 8-row tile:
    // ymm3 accumulates rows 0-3, ymm7 accumulates rows 4-7.
    for(k = 0; k < k_iter; k++)    //loop for number of GEMM operations
    {
        ptr_a01_dup = a01;

        //broadcast 1st row of A01 (only column 3 is needed)
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[0][3]

        a01 += cs_a;    //move to next row

        //load 8x2 block of B10 (columns 0 and 1)
        ymm12 = _mm256_loadu_pd((double const *)b10);                    //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
        ymm13 = _mm256_loadu_pd((double const *)(b10 + D_NR));           //B10[4][0] B10[5][0] B10[6][0] B10[7][0]
        ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b));           //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
        ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b + D_NR));    //B10[4][1] B10[5][1] B10[6][1] B10[7][1]

        ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3);    //ymm3 += B10[0-3][0] * A01[0][3]

        ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7);    //ymm7 += B10[4-7][0] * A01[0][3]

        //broadcast 2nd row of A01
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[1][3]

        a01 += cs_a;    //move to next row of A

        ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3);    //ymm3 += B10[0-3][1] * A01[1][3]

        ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7);    //ymm7 += B10[4-7][1] * A01[1][3]

        //broadcast 3rd row of A01
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[2][3]

        a01 += cs_a;    //move to next row of A01

        //load next 8x2 block of B10 (columns 2 and 3)
        ymm12 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));                  //B10[0-3][2]
        ymm13 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + D_NR));           //B10[4-7][2]
        ymm14 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b));           //B10[0-3][3]
        ymm15 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0] + cs_b + D_NR));    //B10[4-7][3]

        ymm3 = _mm256_fmadd_pd(ymm11, ymm12, ymm3);    //ymm3 += B10[0-3][2] * A01[2][3]

        ymm7 = _mm256_fmadd_pd(ymm11, ymm13, ymm7);    //ymm7 += B10[4-7][2] * A01[2][3]

        //broadcast 4th row of A01
        ymm11 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[3][3]

        a01 += cs_a;    //move to next row of A01

        ymm3 = _mm256_fmadd_pd(ymm11, ymm14, ymm3);    //ymm3 += B10[0-3][3] * A01[3][3]

        ymm7 = _mm256_fmadd_pd(ymm11, ymm15, ymm7);    //ymm7 += B10[4-7][3] * A01[3][3]

        b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
        a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
    }

    ///GEMM code ends///
    ymm8 = _mm256_broadcast_sd((double const *)&AlphaVal);    //register to hold alpha

    //load the single B11 column being solved (col 3)
    ymm11 = _mm256_loadu_pd((double const *)(b11+cs_b_offset[1]));           //B11[0-3][3]
    ymm15 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1] +D_NR));    //B11[4-7][3]

    //B11 := alpha*B11 - (accumulated GEMM result)
    ymm11 = _mm256_fmsub_pd(ymm11, ymm8, ymm3);    //B11[0-3][3] * alpha -= ymm3

    ymm15 = _mm256_fmsub_pd(ymm15, ymm8, ymm7);    //B11[4-7][3] * alpha -= ymm7

    //single column: no cross-column elimination needed, store directly
    //NOTE(review): no divide by the diagonal -- presumably unit-diagonal A11; confirm.
    _mm256_storeu_pd((double *)(b11+ cs_b_offset[1]), ymm11);            //store B11[0-3][3]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[1] + D_NR), ymm15);    //store B11[4-7][3]
}
|
|
}
|
|
}
|
|
// NOTE(review): presumably the D_NR-stride row loop above decremented i past
// zero on its final iteration; snap it back so it indexes the first remainder
// row for the (m & 4) path below -- confirm against the loop header outside
// this region.
if(i<0)
    i += D_NR;
|
|
if((m & 4)) ///implementation for remainder rows(when m_remainder is greater than 4)
|
|
{
|
|
// 4-row remainder path: solve one 4x4 tile of B per iteration, walking the
// block columns of B from right to left (upper-triangular back-substitution).
for(j = (n-D_NR); (j+1) > 0; j -=D_NR)    //loop along n direction
{
    a01 = L + (j+D_NR)*cs_a + (j);    //pointer to block of A to be used for GEMM
    a11 = L + j*cs_a + j;             //pointer to block of A to be used for TRSM
    b10 = B + i + (j+D_NR)*cs_b;      //pointer to block of B to be used for GEMM
    b11 = B + i + j*cs_b;             //pointer to block of B to be used for TRSM

    k_iter = (n-j-D_NR) / D_NR;    //number of times GEMM operations to be performed(in blocks of 4x4)

    ///GEMM for previously solved blocks///

    ///load 4x4 block of b11
    ymm0 = _mm256_loadu_pd((double const *)b11);                       //B11[0-3][0]
    ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b));              //B11[0-3][1]
    ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[0]));    //B11[0-3][2]
    ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b_offset[1]));    //B11[0-3][3]

    //clear the four GEMM accumulators (one per output column)
    ymm4 = _mm256_setzero_pd();
    ymm5 = _mm256_setzero_pd();
    ymm6 = _mm256_setzero_pd();
    ymm7 = _mm256_setzero_pd();

    ///GEMM implementation starts///
    for(k = 0; k < k_iter; k++)    //loop for number of GEMM operations
    {
        ptr_a01_dup = a01;

        //load 4x4 block of b10
        ymm8 = _mm256_loadu_pd((double const *)b10);                        //B10[0-3][0]
        ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b));               //B10[0-3][1]
        ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));    //B10[0-3][2]
        ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1]));    //B10[0-3][3]

        //broadcast 1st row of A01
        ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[0][0]
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[0][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[0][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[0][3]

        a01 += cs_a;    //move to next row of A

        ymm4 = _mm256_fmadd_pd(ymm12, ymm8, ymm4);    //ymm4 += B10[0-3][0] * A01[0][0]
        ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5);    //ymm5 += B10[0-3][0] * A01[0][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6);    //ymm6 += B10[0-3][0] * A01[0][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7);    //ymm7 += B10[0-3][0] * A01[0][3]

        //broadcast 2nd row of A01
        ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[1][0]
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[1][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[1][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[1][3]

        a01 += cs_a;    //move to next row of A

        ymm4 = _mm256_fmadd_pd(ymm12, ymm9, ymm4);    //ymm4 += B10[0-3][1] * A01[1][0]
        ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5);    //ymm5 += B10[0-3][1] * A01[1][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6);    //ymm6 += B10[0-3][1] * A01[1][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7);    //ymm7 += B10[0-3][1] * A01[1][3]

        //broadcast 3rd row of A01
        ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[2][0]
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[2][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[2][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[2][3]

        a01 += cs_a;    //move to next row of A

        ymm4 = _mm256_fmadd_pd(ymm12, ymm10, ymm4);    //ymm4 += B10[0-3][2] * A01[2][0]
        ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5);    //ymm5 += B10[0-3][2] * A01[2][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6);    //ymm6 += B10[0-3][2] * A01[2][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7);    //ymm7 += B10[0-3][2] * A01[2][3]

        //broadcast 4th row of A01
        ymm12 = _mm256_broadcast_sd((double const *)(a01 + 0));    //A01[3][0]
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[3][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[3][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[3][3]

        a01 += cs_a;    //move to next row of A

        ymm4 = _mm256_fmadd_pd(ymm12, ymm11, ymm4);    //ymm4 += B10[0-3][3] * A01[3][0]
        ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5);    //ymm5 += B10[0-3][3] * A01[3][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6);    //ymm6 += B10[0-3][3] * A01[3][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7);    //ymm7 += B10[0-3][3] * A01[3][3]

        b10 += D_NR * cs_b;               //pointer math to find next block of B for GEMM
        a01 = ptr_a01_dup + D_NR*cs_a;    //pointer math to find next block of A for GEMM
    }

    ///GEMM code end///
    ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    //register to store alpha

    //B11 := alpha*B11 - (accumulated GEMM result)
    ymm0 = _mm256_fmsub_pd(ymm0, ymm15, ymm4);    //B11[x][0] * alpha -= ymm4
    ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);    //B11[x][1] * alpha -= ymm5
    ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);    //B11[x][2] * alpha -= ymm6
    ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);    //B11[x][3] * alpha -= ymm7

    ///implement TRSM///

    ///read the strictly-upper entries of the 4x4 block of A11///
    ///(column-major: a11 + p*cs_a + q addresses A11[q][p])///
    a11 += cs_a;

    //2nd col
    ymm5 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][1]

    a11 += cs_a;

    //3rd col
    ymm6 = _mm256_broadcast_sd((double const *)(a11+0));    //A11[0][2]

    ymm9 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][2]

    a11 += cs_a;

    //4th col
    ymm7 = _mm256_broadcast_sd((double const *)(a11+0));     //A11[0][3]

    ymm10 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][3]

    ymm12 = _mm256_broadcast_sd((double const *)(a11+2));    //A11[2][3]

    //back-substitution from the last column upward.
    //NOTE(review): no divide by the diagonal -- presumably unit-diagonal A11; confirm.
    //(Row 3): subtract col-3 contributions from cols 0-2
    ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);

    ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);

    ymm0 = _mm256_fnmadd_pd(ymm3, ymm7, ymm0);

    //(Row 2): subtract col-2 contributions from cols 0-1
    ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);

    ymm0 = _mm256_fnmadd_pd(ymm2, ymm6, ymm0);

    //(Row 1): subtract col-1 contribution from col 0
    ymm0 = _mm256_fnmadd_pd(ymm1, ymm5, ymm0);

    _mm256_storeu_pd((double *)b11, ymm0);                       //store B11[x][0]
    _mm256_storeu_pd((double *)(b11 + cs_b), ymm1);              //store B11[x][1]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2);    //store B11[x][2]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[1]), ymm3);    //store B11[x][3]
}
|
|
if(n_remainder) //implementation for remainder columns(when n is not a multiple of D_NR)
|
|
{
|
|
|
|
a01 = L + (j+D_NR)*cs_a + (j); //pointer to block of A to be used for GEMM
|
|
a11 = L + j*cs_a + j; //pointer to block of A to be used for TRSM
|
|
b10 = B + i + (j+D_NR)*cs_b; //pointer to block of B to be used for GEMM
|
|
b11 = B + i + j*cs_b; //pointer to block of B to be used for TRSM
|
|
|
|
k_iter = (n-j-D_NR) / D_NR; //number of times GEMM operations to be performed(in blocks of 4x4)
|
|
|
|
///GEMM for previous blocks ///
|
|
|
|
ymm4 = _mm256_setzero_pd();
|
|
ymm5 = _mm256_setzero_pd();
|
|
ymm6 = _mm256_setzero_pd();
|
|
ymm7 = _mm256_setzero_pd();
|
|
|
|
///load 4x4 block of b11
|
|
if(3 == n_remainder)
{
    //load the three B11 columns being solved (cols 1, 2 and 3);
    //ymm5/ymm6/ymm7 accumulate the corresponding GEMM updates
    ymm1 = _mm256_loadu_pd((double const *)b11+ cs_b);           //B11[0-3][1]
    ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));    //B11[0-3][2]
    ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3));    //B11[0-3][3]

    ///GEMM implementation starts///
    for(k = 0; k < k_iter; k++)    //loop for number of GEMM operations
    {
        ptr_a01_dup = a01;

        //load 4x4 block of b10
        ymm8 = _mm256_loadu_pd((double const *)b10);                        //B10[0-3][0]
        ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b));               //B10[0-3][1]
        ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));    //B10[0-3][2]
        ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1]));    //B10[0-3][3]

        //broadcast 1st row of A01 (only columns 1-3 are needed)
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[0][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[0][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[0][3]

        a01 += cs_a;    //move to next row of A

        ymm5 = _mm256_fmadd_pd(ymm13, ymm8, ymm5);    //ymm5 += B10[0-3][0] * A01[0][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6);    //ymm6 += B10[0-3][0] * A01[0][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7);    //ymm7 += B10[0-3][0] * A01[0][3]

        //broadcast 2nd row of A01
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[1][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[1][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[1][3]

        a01 += cs_a;    //move to next row of A

        ymm5 = _mm256_fmadd_pd(ymm13, ymm9, ymm5);    //ymm5 += B10[0-3][1] * A01[1][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6);    //ymm6 += B10[0-3][1] * A01[1][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7);    //ymm7 += B10[0-3][1] * A01[1][3]

        //broadcast 3rd row of A01
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[2][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[2][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[2][3]

        a01 += cs_a;    //move to next row of A

        ymm5 = _mm256_fmadd_pd(ymm13, ymm10, ymm5);    //ymm5 += B10[0-3][2] * A01[2][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6);    //ymm6 += B10[0-3][2] * A01[2][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7);    //ymm7 += B10[0-3][2] * A01[2][3]

        //broadcast 4th row of A01
        ymm13 = _mm256_broadcast_sd((double const *)(a01 + 1));    //A01[3][1]
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[3][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[3][3]

        a01 += cs_a;    //move to next row of A

        ymm5 = _mm256_fmadd_pd(ymm13, ymm11, ymm5);    //ymm5 += B10[0-3][3] * A01[3][1]
        ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6);    //ymm6 += B10[0-3][3] * A01[3][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7);    //ymm7 += B10[0-3][3] * A01[3][3]

        b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
        a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
    }

    ///GEMM code end///
    ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    //register to store alpha

    //B11 := alpha*B11 - (accumulated GEMM result)
    ymm1 = _mm256_fmsub_pd(ymm1, ymm15, ymm5);    //B11[x][1] * alpha -= ymm5
    ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);    //B11[x][2] * alpha -= ymm6
    ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);    //B11[x][3] * alpha -= ymm7

    ///implement TRSM///

    ///read the needed strictly-upper entries of the 4x4 block of A11///
    a11 += 2 * cs_a;

    //3rd col: a11 + 2*cs_a + 1 addresses A11[1][2]
    ymm9 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][2]

    a11 += cs_a;

    //4th col
    ymm10 = _mm256_broadcast_sd((double const *)(a11+1));    //A11[1][3]

    ymm12 = _mm256_broadcast_sd((double const *)(a11+2));    //A11[2][3]

    //NOTE(review): no divide by the diagonal -- presumably unit-diagonal A11; confirm.
    //(Row 3): subtract col-3 contributions from cols 1-2
    ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);

    ymm1 = _mm256_fnmadd_pd(ymm3, ymm10, ymm1);

    //(Row 2): subtract col-2 contribution from col 1
    ymm1 = _mm256_fnmadd_pd(ymm2, ymm9, ymm1);

    _mm256_storeu_pd((double *)(b11 + cs_b), ymm1);              //store B11[x][1]
    _mm256_storeu_pd((double *)(b11 + cs_b_offset[0]), ymm2);    //store B11[x][2]
    _mm256_storeu_pd((double *)(b11 + cs_b*3), ymm3);            //store B11[x][3]
}
|
|
else if(2 == n_remainder)
{
    //load the two B11 columns being solved (cols 2 and 3);
    //ymm6/ymm7 accumulate the corresponding GEMM updates
    ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));    //B11[0-3][2]
    ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3));    //B11[0-3][3]

    ///GEMM implementation starts///
    for(k = 0; k < k_iter; k++)    //loop for number of GEMM operations
    {
        ptr_a01_dup = a01;

        //load 4x4 block of b10
        ymm8 = _mm256_loadu_pd((double const *)b10);                        //B10[0-3][0]
        ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b));               //B10[0-3][1]
        ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0]));    //B10[0-3][2]
        ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1]));    //B10[0-3][3]

        //broadcast 1st row of A01 (only columns 2 and 3 are needed)
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[0][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[0][3]

        a01 += cs_a;    //move to next row of A

        ymm6 = _mm256_fmadd_pd(ymm14, ymm8, ymm6);    //ymm6 += B10[0-3][0] * A01[0][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7);    //ymm7 += B10[0-3][0] * A01[0][3]

        //broadcast 2nd row of A01
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[1][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[1][3]

        a01 += cs_a;    //move to next row of A

        ymm6 = _mm256_fmadd_pd(ymm14, ymm9, ymm6);    //ymm6 += B10[0-3][1] * A01[1][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7);    //ymm7 += B10[0-3][1] * A01[1][3]

        //broadcast 3rd row of A01
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[2][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[2][3]

        a01 += cs_a;    //move to next row of A

        ymm6 = _mm256_fmadd_pd(ymm14, ymm10, ymm6);    //ymm6 += B10[0-3][2] * A01[2][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7);    //ymm7 += B10[0-3][2] * A01[2][3]

        //broadcast 4th row of A01
        ymm14 = _mm256_broadcast_sd((double const *)(a01 + 2));    //A01[3][2]
        ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3));    //A01[3][3]

        a01 += cs_a;    //move to next row of A

        ymm6 = _mm256_fmadd_pd(ymm14, ymm11, ymm6);    //ymm6 += B10[0-3][3] * A01[3][2]
        ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7);    //ymm7 += B10[0-3][3] * A01[3][3]

        b10 += D_NR * cs_b;                   //pointer math to find next block of B for GEMM
        a01 = ptr_a01_dup + (D_NR * cs_a);    //pointer math to find next block of A for GEMM
    }

    ///GEMM code end///
    ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);    //register to store alpha

    //B11 := alpha*B11 - (accumulated GEMM result)
    ymm2 = _mm256_fmsub_pd(ymm2, ymm15, ymm6);    //B11[x][2] * alpha -= ymm6
    ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7);    //B11[x][3] * alpha -= ymm7

    ///implement TRSM///

    ///read the single off-diagonal entry of A11 that couples cols 2 and 3///
    a11 += 3 * cs_a;

    //4th col: a11 + 3*cs_a + 2 addresses A11[2][3]
    ymm12 = _mm256_broadcast_sd((double const *)(a11+2));    //A11[2][3]

    //NOTE(review): no divide by the diagonal -- presumably unit-diagonal A11; confirm.
    //(Row 3): subtract col-3 contribution from col 2
    ymm2 = _mm256_fnmadd_pd(ymm3, ymm12, ymm2);

    _mm256_storeu_pd((double *)(b11+ cs_b * 2), ymm2);    //store B11[x][2]
    _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3);   //store B11[x][3]
}
|
|
else if(1 == n_remainder)
|
|
{
|
|
ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); //B11[0][0] B11[1][0] B11[2][0] B11[3][0]
|
|
|
|
///GEMM implementation starts///
|
|
|
|
for(k = 0; k < k_iter; k++) //loop for number of GEMM operations
|
|
{
|
|
ptr_a01_dup = a01;
|
|
|
|
//load 4x4 bblock of b10
|
|
ymm8 = _mm256_loadu_pd((double const *)b10); //B10[0][0] B10[1][0] B10[2][0] B10[3][0]
|
|
ymm9 = _mm256_loadu_pd((double const *)(b10 + cs_b)); //B10[0][1] B10[1][1] B10[2][1] B10[3][1]
|
|
ymm10 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[0])); //B10[0][2] B10[1][2] B10[2][2] B10[3][2]
|
|
ymm11 = _mm256_loadu_pd((double const *)(b10 + cs_b_offset[1])); //B10[0][3] B10[1][3] B10[2][3] B10[3][3]
|
|
|
|
//broadcast 1st row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[0][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm8, ymm7); //ymm7 += (B10[0][0]*A01[0][3] B10[1][0]*A01[0][3] B10[2][0]*A01[0][3] B10[3][0]*A01[0][3])
|
|
|
|
//broadcast 2nd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[1][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm9, ymm7); //ymm7 += (B10[0][1]*A01[1][3] B10[1][1]*A01[1][3] B10[2][1]*A01[1][3] B10[3][1]*A01[1][3])
|
|
|
|
//braodcast 3rd row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[2][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm10, ymm7); //ymm7 += (B10[0][2]*A01[2][3] B10[1][2]*A01[2][3] B10[2][2]*A01[2][3] B10[3][2]*A01[2][3])
|
|
|
|
//broadcast 4th row of A01
|
|
ymm15 = _mm256_broadcast_sd((double const *)(a01 + 3)); //A01[3][3]
|
|
|
|
a01 += cs_a; //move to next row of A
|
|
|
|
ymm7 = _mm256_fmadd_pd(ymm15, ymm11, ymm7); //ymm7 += (B10[0][3]*A01[3][3] B10[1][3]*A01[3][3] B10[2][3]*A01[3][3] B10[3][3]*A01[3][3])
|
|
|
|
b10 += D_NR * cs_b; //pointer math to find next block of B for GEMM
|
|
a01 = ptr_a01_dup + (D_NR * cs_a); //pointer math to find next block of A for GEMM
|
|
}
|
|
|
|
///GEMM code end///
|
|
ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); //register to store alpha
|
|
|
|
ymm3 = _mm256_fmsub_pd(ymm3, ymm15, ymm7); //B11[x][3] -= ymm7
|
|
|
|
_mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[x][0])
|
|
}
|
|
}
|
|
m_remainder -= 4;
|
|
i -= 4;
|
|
}
|
|
if(m_remainder) ///implementation for remainder rows
|
|
{
|
|
dtrsm_small_XAutB_unitDiag(L, B, AlphaVal, m_remainder, n, cs_a, cs_b);
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
/*
|
|
* AX = Alpha*B, Single precision, A: lower triangular
|
|
 * This kernel implementation supports matrices A and B such that m is equal to BLI_AlXB_M_SP and n is a multiple of 8
|
|
*/
|
|
|
|
static err_t bli_strsm_small_AlXB (
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
obj_t alpha, beta; // gemm parameters
|
|
obj_t Ga, Gb, Gc; // for GEMM
|
|
int m = bli_obj_length(b); // number of rows of matrix B
|
|
int n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of A
|
|
int ldb = bli_obj_col_stride(b); // column stride of B
|
|
|
|
int rsa = bli_obj_row_stride(a); // row stride of A
|
|
int rsb = bli_obj_row_stride(b); // row stride of B
|
|
|
|
int i = 0;
|
|
int j;
|
|
int blk_size = 8;
|
|
int isUnitDiag = bli_obj_has_unit_diag(a);
|
|
|
|
float alphaVal;
|
|
float* restrict L = a->buffer;
|
|
float* restrict B = b->buffer;
|
|
|
|
if (m != BLI_AlXB_M_SP || (n&7) != 0)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
if ( (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));
|
|
|
|
/* Small _GEMM preparation code */
|
|
bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &alpha );
|
|
bli_obj_create( BLIS_FLOAT, 1, 1, 0, 0, &beta );
|
|
|
|
/* B = B - A*B */
|
|
bli_setsc( -(1.0), 0.0, &alpha );
|
|
bli_setsc( (1.0), 0.0, &beta );
|
|
|
|
|
|
bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, blk_size, a->buffer, rsa, lda, &Ga);
|
|
bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gb);
|
|
bli_obj_create_with_attached_buffer( BLIS_FLOAT, blk_size, n, b->buffer, rsb, ldb, &Gc);
|
|
|
|
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Ga );
|
|
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gb );
|
|
bli_obj_set_conjtrans( BLIS_NO_TRANSPOSE, &Gc );
|
|
|
|
//first block of trsm
|
|
Gb.buffer = (void*)(B + i);
|
|
|
|
//trsm of first 8xn block
|
|
if (alphaVal != 1)
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
blis_strsm_microkernel_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel;
|
|
}
|
|
else
|
|
{
|
|
blis_strsm_microkernel_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag;
|
|
}
|
|
bli_setsc( alphaVal, 0.0, &beta );
|
|
}
|
|
else
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel;
|
|
}
|
|
else
|
|
{
|
|
blis_strsm_microkernel_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
fp_blis_strsm_microkernel = blis_strsm_microkernel_unitDiag;
|
|
}
|
|
}
|
|
|
|
//gemm update
|
|
for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT
|
|
{
|
|
Ga.buffer = (void*)(L + j + i*lda);
|
|
Gc.buffer = (void*)(B + j);
|
|
|
|
bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb
|
|
}
|
|
|
|
//trsm of remaining blocks
|
|
for (i = blk_size; i < m; i += blk_size)
|
|
{
|
|
Gb.buffer = (void*)(B + i);
|
|
|
|
fp_blis_strsm_microkernel((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
|
|
for (j = i + blk_size; j < m; j += blk_size) // for rows upto multiple of BLOCK_HEIGHT
|
|
{
|
|
Ga.buffer = (void*)(L + j + i*lda);
|
|
Gc.buffer = (void*)(B + j);
|
|
|
|
bli_gemm_small(&alpha, &Ga, &Gb, &beta, &Gc, cntx, cntl ); // Gc = beta*Gc + alpha*Ga *Gb
|
|
}
|
|
|
|
} // End of for loop - i
|
|
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
|
|
|
|
/*
|
|
* XA' = Alpha*B, Single precision, A: lower triangular
|
|
* This kernel implementation supports matrices A and B such that
|
|
* m and n are multiples of 8 and n is less than or equal to BLI_XAltB_N_SP
|
|
*/
|
|
static err_t bli_strsm_small_XAltB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
int m = bli_obj_length(a); // number of rows of matrix B
|
|
int n = bli_obj_length(b); // number of columns of matrix B
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of A
|
|
int ldb = bli_obj_col_stride(b); // column stride of B
|
|
|
|
int rsa = bli_obj_row_stride(a); // row stride of A
|
|
int rsb = bli_obj_row_stride(b); // row stride of B
|
|
|
|
int i = 0;
|
|
int isUnitDiag = bli_obj_has_unit_diag(a);
|
|
|
|
float alphaVal;
|
|
float *L = a->buffer;
|
|
float *B = b->buffer;
|
|
|
|
if ((m&7) != 0 || (n&7) != 0)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
if ( n > BLI_XAltB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));
|
|
|
|
if (alphaVal != 1)
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
else
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
else
|
|
{
|
|
trsm_XAtB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
/*
|
|
* A'X = Alpha*B, Single precision, A: upper triangular
|
|
* This kernel implementation supports matrices A and B such that
|
|
* m and n are multiples of 8, m is less than or equal to BLI_AutXB_M_SP and n is less than or equal to BLI_AutXB_N_SP
|
|
*/
|
|
static err_t bli_strsm_small_AutXB(
|
|
side_t side,
|
|
obj_t* AlphaObj,
|
|
obj_t* a,
|
|
obj_t* b,
|
|
cntx_t* cntx,
|
|
cntl_t* cntl
|
|
)
|
|
{
|
|
int m = bli_obj_width(a); // number of rows of matrix A (since At, so width is taken)
|
|
int n = bli_obj_width(b); // number of columns of matrix B
|
|
|
|
int lda = bli_obj_col_stride(a); // column stride of A
|
|
int ldb = bli_obj_col_stride(b); // column stride of B
|
|
|
|
int rsa = bli_obj_row_stride(a); // row stride of A
|
|
int rsb = bli_obj_row_stride(b); // row stride of B
|
|
|
|
int i = 0;
|
|
int isUnitDiag = bli_obj_has_unit_diag(a);
|
|
|
|
float alphaVal;
|
|
float *L = a->buffer;
|
|
float *B = b->buffer;
|
|
|
|
if ((m&7) != 0 || (n&7) != 0)
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
if ( m > BLI_AutXB_M_SP || n > BLI_AutXB_N_SP || (m*(m + n)) > BLIS_SMALL_MATRIX_THRES_TRSM )
|
|
{
|
|
return BLIS_NOT_YET_IMPLEMENTED;
|
|
}
|
|
|
|
alphaVal = *((float *)bli_obj_buffer_for_const(BLIS_FLOAT, AlphaObj));
|
|
|
|
if (alphaVal != 1)
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices_alpha((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
else
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb, alphaVal);
|
|
}
|
|
}
|
|
else
|
|
{
|
|
if (isUnitDiag == 0)
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
else
|
|
{
|
|
trsm_AutXB_block_allSmallSizedMatrices_unitDiag((L + i * lda + i), (B + i), m, n, rsa, rsb, lda, ldb);
|
|
}
|
|
}
|
|
return BLIS_SUCCESS;
|
|
}
|
|
|
|
///////////////////////////// AX=B ///////////////////////////////
|
|
static void blis_strsm_microkernel_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
|
|
{
|
|
float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags;
|
|
__m256 alphaReg;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4rth col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagnol elements
|
|
reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
static void blis_strsm_microkernel_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alphaVal)
|
|
{
|
|
//float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags;
|
|
__m256 alphaReg;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
//reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alphaVal);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4rth col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//8th col
|
|
//ptr_l += cs_l;
|
|
//mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
//mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
//mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagnol elements
|
|
//reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
|
|
//end loop of cols
|
|
}
|
|
|
|
static void blis_strsm_microkernel_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
//reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 16x8 block of B into registers, where 16 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 16x16 block of L, where 16 is the blk_height and 16 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 16x16 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4rth col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//8th col
|
|
//ptr_l += cs_l;
|
|
//mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
//mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
//mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagnol elements
|
|
//reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 8x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
//mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
//mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
//mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
//mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
//mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
//mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
//mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
//mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
//mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
//mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
//mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
//mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
//mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
//mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
//mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
//mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
//mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
//mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
static void blis_strsm_microkernel(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int j;
|
|
int cs_b_offset[6];
|
|
//int row2, row4, row6;
|
|
float *ptr_b_dup;
|
|
|
|
//70 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_cols[8];
|
|
__m256 mat_a_cols_rearr[36];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags;
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
reciprocal_diags = _mm256_broadcast_ss((float const *)&ones);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//read first set of 8x8 block of B into registers, where 8 is the blk_height and 8 is the blk_width for B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
//_mm_prefetch((char*)(ptr_l + 0), _MM_HINT_T0);
|
|
//row2 = (cs_l << 1);
|
|
//row4 = (cs_l << 2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
//_mm_prefetch((char*)(ptr_l + cs_l), _MM_HINT_T0);
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
//_mm_prefetch((char*)(ptr_l + row2), _MM_HINT_T0);
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
//_mm_prefetch((char*)(ptr_l + row2 + cs_l), _MM_HINT_T0);
|
|
//row6 = row2 + row4;
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
//_mm_prefetch((char*)(ptr_l + row4), _MM_HINT_T0);
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
//_mm_prefetch((char*)(ptr_l + row4 + cs_l), _MM_HINT_T0);
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
//_mm_prefetch((char*)(ptr_l + row6), _MM_HINT_T0);
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
//_mm_prefetch((char*)(ptr_l + row6 + cs_l), _MM_HINT_T0);
|
|
|
|
//reciprocal_diags = _mm256_loadu_ps((float const *)ones);
|
|
|
|
//read first set of 8x8 block of L, where 8 is the blk_height and 8 is the blk_width for L
|
|
/*mat_a_cols[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[1] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[2] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[3] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[4] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[5] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[6] = _mm256_loadu_ps((float const *)ptr_l);
|
|
ptr_l += cs_l;
|
|
mat_a_cols[7] = _mm256_loadu_ps((float const *)ptr_l);*/
|
|
|
|
//Shuffle to rearrange/transpose 8x8 block of L into contiguous row-wise registers
|
|
//tmpRegs[0] = _mm256_castps256_ps128(mat_a_cols[0]); //zero latency, no instruction added actually.
|
|
//mat_a_cols_rearr[0] = _mm256_broadcastss_ps(tmpRegs[0]);
|
|
//1st col
|
|
mat_a_cols_rearr[0] = _mm256_broadcast_ss((float const *)(ptr_l+0));
|
|
mat_a_cols_rearr[1] = _mm256_broadcast_ss((float const *)(ptr_l+1));
|
|
mat_a_cols_rearr[3] = _mm256_broadcast_ss((float const *)(ptr_l+2));
|
|
mat_a_cols_rearr[6] = _mm256_broadcast_ss((float const *)(ptr_l+3));
|
|
mat_a_cols_rearr[10] = _mm256_broadcast_ss((float const *)(ptr_l+4));
|
|
mat_a_cols_rearr[15] = _mm256_broadcast_ss((float const *)(ptr_l+5));
|
|
mat_a_cols_rearr[21] = _mm256_broadcast_ss((float const *)(ptr_l+6));
|
|
mat_a_cols_rearr[28] = _mm256_broadcast_ss((float const *)(ptr_l+7));
|
|
//2nd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_cols_rearr[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[7] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[11] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[16] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[22] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[29] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//3rd col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[5] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_cols_rearr[8] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[12] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[17] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[23] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[30] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//4th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[9] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_cols_rearr[13] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[18] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[24] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[31] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//5th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[14] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_cols_rearr[19] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[25] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[32] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//6th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[20] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_cols_rearr[26] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[33] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//7th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[27] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_cols_rearr[34] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
//8th col
|
|
ptr_l += cs_l;
|
|
mat_a_cols_rearr[35] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
numCols_b -= 8; // blk_width = 8
|
|
|
|
//compute reciprocals of L(i,i) and broadcast in registers
|
|
mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[2]);
|
|
mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_cols_rearr[5], mat_a_cols_rearr[9]);
|
|
mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_cols_rearr[14], mat_a_cols_rearr[20]);
|
|
mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_cols_rearr[27], mat_a_cols_rearr[35]);
|
|
|
|
//mat_a_diag_inv[1] = _mm256_permute_ps(mat_a_diag_inv[1], 0x55);
|
|
//mat_a_diag_inv[3] = _mm256_permute_ps(mat_a_diag_inv[3], 0x55);
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0x20);
|
|
|
|
//reciprocal of diagonal elements
|
|
reciprocal_diags = _mm256_div_ps(reciprocal_diags, mat_a_diag_inv[0]);
|
|
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 8x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Read next set of B columns
|
|
ptr_b += (cs_b + cs_b_offset[5]);
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + (cs_b)));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5]));
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
//Last block trsm processing
|
|
ptr_b_dup = ptr_b;
|
|
|
|
/*Shuffle to rearrange/transpose 16x8 block of B into contiguous row-wise registers*/
|
|
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_cols_rearr[1], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[3], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[6], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[10], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[15], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[21], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[28], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_cols_rearr[4], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[7], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[11], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[16], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[22], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[29], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_cols_rearr[8], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[12], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[17], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[23], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[30], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags, 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_cols_rearr[13], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[18], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[24], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[31], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags, 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_cols_rearr[19], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[25], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[32], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags, 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_cols_rearr[26], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[33], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags, 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_cols_rearr[34], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
//--> Transpose and store results of columns of B block <--//
|
|
////unpacklow////
|
|
mat_a_cols[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_a_cols[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_a_cols[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_a_cols[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_cols[4] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x44);
|
|
mat_a_cols[5] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0xEE);
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x44);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0xEE);
|
|
#else
|
|
mat_a_cols[6] = _mm256_shuffle_ps(mat_a_cols[0], mat_a_cols[1], 0x4E);
|
|
mat_a_cols[7] = _mm256_shuffle_ps(mat_a_cols[2], mat_a_cols[3], 0x4E);
|
|
mat_a_cols[4] = _mm256_blend_ps(mat_a_cols[0], mat_a_cols[6], 0xCC);
|
|
mat_a_cols[5] = _mm256_blend_ps(mat_a_cols[1], mat_a_cols[6], 0x33);
|
|
mat_a_cols[6] = _mm256_blend_ps(mat_a_cols[2], mat_a_cols[7], 0xCC);
|
|
mat_a_cols[7] = _mm256_blend_ps(mat_a_cols[3], mat_a_cols[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_a_cols[0] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x20);
|
|
mat_a_cols[4] = _mm256_permute2f128_ps(mat_a_cols[4], mat_a_cols[6], 0x31);
|
|
mat_a_cols[1] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x20);
|
|
mat_a_cols[5] = _mm256_permute2f128_ps(mat_a_cols[5], mat_a_cols[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_a_cols[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_a_cols[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_a_cols[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_a_cols[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_a_cols[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_a_cols[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_a_cols[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_a_cols[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_a_cols[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_a_cols[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_a_cols[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_a_cols[7]);
|
|
//end loop of cols
|
|
}
|
|
|
|
#if OPT_CACHE_BLOCKING_L1 //new intrinsic kernels
|
|
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) until (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) until (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) until (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) until (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4th row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) until (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) until (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b0 (mat_b_col[0]) with the 8 broadcasted elements of A column 0
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b3 (mat_b_col[3]) with the 8 broadcasted elements of A column 3
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b4 (mat_b_col[4]) with the 8 broadcasted elements of A column 4
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b5 (mat_b_col[5]) with the 8 broadcasted elements of A column 5
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b6 (mat_b_col[6]) with the 8 broadcasted elements of A column 6
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b7 (mat_b_col[7]) with the 8 broadcasted elements of A column 7
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0)
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) up till (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up till (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up till (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up till (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up till (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) up till (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0)
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
i = i1 + r;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2 + r;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
i4 = k >> 3;
|
|
ptr_l_dup += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + i + 7));
|
|
ptr_l_dup += cs_l;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
i += cs_l_offset[6];
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
{
|
|
i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i2);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i2));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i2));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i2));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i2));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i2));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i2));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i2));
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + r, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+r), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + r), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + r), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + r), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + r), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + r), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + r), mat_b_rearr[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
#else //rel 1.0 intrinsic kernels (NOT OPT_CACHE_BLOCKING_L1)
|
|
static void trsm_XAtB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
|
|
mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
|
|
mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
|
|
mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
|
|
mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
|
|
mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
|
|
mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
|
|
mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_rearr[0][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_rearr[1][0], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_rearr[2][0], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_rearr[3][0], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_rearr[4][0], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_rearr[5][0], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_rearr[6][0], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_rearr[7][0], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += cs_l_offset[6];
|
|
mat_a_cols_rearr[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_cols_rearr[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_cols_rearr[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_cols_rearr[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_cols_rearr[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_cols_rearr[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_cols_rearr[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_cols_rearr[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_cols_rearr[0], mat_a_cols_rearr[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_cols_rearr[2], mat_a_cols_rearr[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_cols_rearr[4], mat_a_cols_rearr[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_cols_rearr[6], mat_a_cols_rearr[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
|
|
mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
|
|
mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
|
|
mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
|
|
mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
|
|
mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
|
|
mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
|
|
mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
|
|
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[k][0] = _mm256_mul_ps(mat_b_rearr[k][0], mat_a_diag_inv[0]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[k][1] = _mm256_mul_ps(mat_b_rearr[k][1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[k][2] = _mm256_mul_ps(mat_b_rearr[k][2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[k][3] = _mm256_mul_ps(mat_b_rearr[k][3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[k][4] = _mm256_mul_ps(mat_b_rearr[k][4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[k][5] = _mm256_mul_ps(mat_b_rearr[k][5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[k][6] = _mm256_mul_ps(mat_b_rearr[k][6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[k][7] = _mm256_mul_ps(mat_b_rearr[k][7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
//__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
//(Row0)
|
|
mat_b_col[0] = mat_b_rearr[0][0];
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): already done
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_XAtB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[16][8];
|
|
//__m256 mat_a_cols_rearr[8];
|
|
__m256 mat_a_blk_elems[64];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7][0] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[0][0] = _mm256_mul_ps(mat_b_rearr[0][0], alphaReg);
|
|
mat_b_rearr[1][0] = _mm256_mul_ps(mat_b_rearr[1][0], alphaReg);
|
|
mat_b_rearr[2][0] = _mm256_mul_ps(mat_b_rearr[2][0], alphaReg);
|
|
mat_b_rearr[3][0] = _mm256_mul_ps(mat_b_rearr[3][0], alphaReg);
|
|
mat_b_rearr[4][0] = _mm256_mul_ps(mat_b_rearr[4][0], alphaReg);
|
|
mat_b_rearr[5][0] = _mm256_mul_ps(mat_b_rearr[5][0], alphaReg);
|
|
mat_b_rearr[6][0] = _mm256_mul_ps(mat_b_rearr[6][0], alphaReg);
|
|
mat_b_rearr[7][0] = _mm256_mul_ps(mat_b_rearr[7][0], alphaReg);
|
|
|
|
//(Row0)
|
|
mat_b_col[0] = mat_b_rearr[0][0];
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[1][0]);//d = c - (a*b)
|
|
mat_b_rearr[2][0] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[2][0]);//d = c - (a*b)
|
|
mat_b_rearr[3][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[2], mat_b_rearr[3][0]);//d = c - (a*b)
|
|
mat_b_rearr[4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[2], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[2], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[3], mat_b_rearr[4][0]);//d = c - (a*b)
|
|
mat_b_rearr[5][0] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[3], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[3], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[3], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[4], mat_b_rearr[5][0]);//d = c - (a*b)
|
|
mat_b_rearr[6][0] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[4], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[4], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[5], mat_b_rearr[6][0]);//d = c - (a*b)
|
|
mat_b_rearr[7][0] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[5], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[6], mat_b_rearr[7][0]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_col[7]);
|
|
|
|
//i += cs_b_offset[6];
|
|
//ptr_b_dup += cs_b_offset[6];
|
|
i += 8;
|
|
ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += 8;
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += cs_b_offset[6];
|
|
i1 += cs_b_offset[6];
|
|
i3 += cs_l_offset[6];
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (k = 0; k < numCols_b; k += 8)
|
|
{
|
|
i = i1 + k;
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[i2][0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[i2][1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[i2][2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[i2][3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[i2][4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[i2][5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[i2][6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[i2][7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
mat_b_rearr[i2][0] = _mm256_mul_ps(mat_b_rearr[i2][0], alphaReg);
|
|
mat_b_rearr[i2][1] = _mm256_mul_ps(mat_b_rearr[i2][1], alphaReg);
|
|
mat_b_rearr[i2][2] = _mm256_mul_ps(mat_b_rearr[i2][2], alphaReg);
|
|
mat_b_rearr[i2][3] = _mm256_mul_ps(mat_b_rearr[i2][3], alphaReg);
|
|
mat_b_rearr[i2][4] = _mm256_mul_ps(mat_b_rearr[i2][4], alphaReg);
|
|
mat_b_rearr[i2][5] = _mm256_mul_ps(mat_b_rearr[i2][5], alphaReg);
|
|
mat_b_rearr[i2][6] = _mm256_mul_ps(mat_b_rearr[i2][6], alphaReg);
|
|
mat_b_rearr[i2][7] = _mm256_mul_ps(mat_b_rearr[i2][7], alphaReg);
|
|
|
|
i2++;
|
|
}
|
|
|
|
i = 0;
|
|
i2 = 0;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 1));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 2));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 3));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 4));
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 5));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 6));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + i + 7));
|
|
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 1));
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 2));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 3));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 4));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 5));
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 6));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + i + 7));
|
|
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i));
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 1));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 2));
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 3));
|
|
mat_a_blk_elems[28] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 4));
|
|
mat_a_blk_elems[29] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 5));
|
|
mat_a_blk_elems[30] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 6));
|
|
mat_a_blk_elems[31] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + i + 7));
|
|
|
|
// _mm256_permute2f128_ps()
|
|
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[32] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i));
|
|
mat_a_blk_elems[33] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 1));
|
|
mat_a_blk_elems[34] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 2));
|
|
mat_a_blk_elems[35] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 3));
|
|
mat_a_blk_elems[36] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 4));
|
|
mat_a_blk_elems[37] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 5));
|
|
mat_a_blk_elems[38] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 6));
|
|
mat_a_blk_elems[39] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + i + 7));
|
|
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[40] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i));
|
|
mat_a_blk_elems[41] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 1));
|
|
mat_a_blk_elems[42] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 2));
|
|
mat_a_blk_elems[43] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 3));
|
|
mat_a_blk_elems[44] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 4));
|
|
mat_a_blk_elems[45] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 5));
|
|
mat_a_blk_elems[46] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 6));
|
|
mat_a_blk_elems[47] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + i + 7));
|
|
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[48] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i));
|
|
mat_a_blk_elems[49] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 1));
|
|
mat_a_blk_elems[50] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 2));
|
|
mat_a_blk_elems[51] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 3));
|
|
mat_a_blk_elems[52] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 4));
|
|
mat_a_blk_elems[53] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 5));
|
|
mat_a_blk_elems[54] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 6));
|
|
mat_a_blk_elems[55] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + i + 7));
|
|
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[56] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i));
|
|
mat_a_blk_elems[57] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 1));
|
|
mat_a_blk_elems[58] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 2));
|
|
mat_a_blk_elems[59] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 3));
|
|
mat_a_blk_elems[60] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 4));
|
|
mat_a_blk_elems[61] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 5));
|
|
mat_a_blk_elems[62] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 6));
|
|
mat_a_blk_elems[63] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5] + i + 7));
|
|
|
|
i += cs_l_offset[6];
|
|
|
|
for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
|
|
i4 = i2 + k;
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
i4 = k >> 3;
|
|
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_col[1], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_col[1], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_col[1], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_col[1], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_col[1], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_col[1], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_col[1], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_col[1], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_col[2], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_col[2], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_col[2], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_col[2], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_col[2], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_col[2], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_col[2], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_col[2], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_col[3], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_col[3], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_col[3], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_col[3], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[28], mat_b_col[3], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[29], mat_b_col[3], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[30], mat_b_col[3], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[31], mat_b_col[3], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[32], mat_b_col[4], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[33], mat_b_col[4], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[34], mat_b_col[4], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[35], mat_b_col[4], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[36], mat_b_col[4], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[37], mat_b_col[4], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[38], mat_b_col[4], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[39], mat_b_col[4], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[40], mat_b_col[5], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[41], mat_b_col[5], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[42], mat_b_col[5], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[43], mat_b_col[5], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[44], mat_b_col[5], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[45], mat_b_col[5], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[46], mat_b_col[5], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[47], mat_b_col[5], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[48], mat_b_col[6], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[49], mat_b_col[6], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[50], mat_b_col[6], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[51], mat_b_col[6], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[52], mat_b_col[6], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[53], mat_b_col[6], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[54], mat_b_col[6], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[55], mat_b_col[6], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[i4][0] = _mm256_fnmadd_ps(mat_a_blk_elems[56], mat_b_col[7], mat_b_rearr[i4][0]);//d = c - (a*b)
|
|
mat_b_rearr[i4][1] = _mm256_fnmadd_ps(mat_a_blk_elems[57], mat_b_col[7], mat_b_rearr[i4][1]);//d = c - (a*b)
|
|
mat_b_rearr[i4][2] = _mm256_fnmadd_ps(mat_a_blk_elems[58], mat_b_col[7], mat_b_rearr[i4][2]);//d = c - (a*b)
|
|
mat_b_rearr[i4][3] = _mm256_fnmadd_ps(mat_a_blk_elems[59], mat_b_col[7], mat_b_rearr[i4][3]);//d = c - (a*b)
|
|
mat_b_rearr[i4][4] = _mm256_fnmadd_ps(mat_a_blk_elems[60], mat_b_col[7], mat_b_rearr[i4][4]);//d = c - (a*b)
|
|
mat_b_rearr[i4][5] = _mm256_fnmadd_ps(mat_a_blk_elems[61], mat_b_col[7], mat_b_rearr[i4][5]);//d = c - (a*b)
|
|
mat_b_rearr[i4][6] = _mm256_fnmadd_ps(mat_a_blk_elems[62], mat_b_col[7], mat_b_rearr[i4][6]);//d = c - (a*b)
|
|
mat_b_rearr[i4][7] = _mm256_fnmadd_ps(mat_a_blk_elems[63], mat_b_col[7], mat_b_rearr[i4][7]);//d = c - (a*b)
|
|
|
|
//end loop of cols
|
|
}
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + i + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + i + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + i + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + i + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + i + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + i + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
i += cs_l;
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + i + 7));
|
|
|
|
k = 0;
|
|
for (i = 0; i < numCols_b; i+=8)
|
|
{
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
|
|
//(Row0): already done
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[k][1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[k][0], mat_b_rearr[k][1]);//d = c - (a*b)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[k][0], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[k][0], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[k][0], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[k][0], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[k][0], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[k][0], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[k][2] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_rearr[k][1], mat_b_rearr[k][2]);//d = c - (a*b)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[8], mat_b_rearr[k][1], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[9], mat_b_rearr[k][1], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[10], mat_b_rearr[k][1], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[11], mat_b_rearr[k][1], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[12], mat_b_rearr[k][1], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[k][3] = _mm256_fnmadd_ps(mat_a_blk_elems[13], mat_b_rearr[k][2], mat_b_rearr[k][3]);//d = c - (a*b)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[14], mat_b_rearr[k][2], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[15], mat_b_rearr[k][2], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[16], mat_b_rearr[k][2], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[17], mat_b_rearr[k][2], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[k][4] = _mm256_fnmadd_ps(mat_a_blk_elems[18], mat_b_rearr[k][3], mat_b_rearr[k][4]);//d = c - (a*b)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[19], mat_b_rearr[k][3], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[20], mat_b_rearr[k][3], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[21], mat_b_rearr[k][3], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[k][5] = _mm256_fnmadd_ps(mat_a_blk_elems[22], mat_b_rearr[k][4], mat_b_rearr[k][5]);//d = c - (a*b)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[23], mat_b_rearr[k][4], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[24], mat_b_rearr[k][4], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[k][6] = _mm256_fnmadd_ps(mat_a_blk_elems[25], mat_b_rearr[k][5], mat_b_rearr[k][6]);//d = c - (a*b)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[26], mat_b_rearr[k][5], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[k][7] = _mm256_fnmadd_ps(mat_a_blk_elems[27], mat_b_rearr[k][6], mat_b_rearr[k][7]);//d = c - (a*b)
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
//Store the computed B columns
|
|
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i, mat_b_rearr[k][0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b) + i), mat_b_rearr[k][1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i), mat_b_rearr[k][2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i), mat_b_rearr[k][3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i), mat_b_rearr[k][4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i), mat_b_rearr[k][5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i), mat_b_rearr[k][6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i), mat_b_rearr[k][7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
}
|
|
|
|
|
|
}
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
#endif //OPT_CACHE_BLOCKING_L1
|
|
|
|
//////////////////////////// AutX=B ///////////////////////
|
|
static void trsm_AutXB_block_allSmallSizedMatrices(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += 8;
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
float ones = 1.0;
|
|
int i, i1, i2, i3, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
__m256 mat_a_diag_inv[8];
|
|
__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
|
|
reciprocal_diags[0] = _mm256_broadcast_ss((float const *)(&ones));
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
//read diag elems of L 16x16 block
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + cs_l_offset[5]);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
reciprocal_diags[1] = reciprocal_diags[0];
|
|
|
|
//pack first 8 diags together
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagnal elements 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv[0] = _mm256_unpacklo_ps(mat_a_diag_inv[0], mat_a_diag_inv[0]);
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], mat_a_diag_inv[0]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], mat_a_diag_inv[1]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], mat_a_diag_inv[2]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4rth row elements of B
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], mat_a_diag_inv[3]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 4rth row elements of B
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], mat_a_diag_inv[4]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 5th row elements of B
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], mat_a_diag_inv[5]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 6th row elements of B
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], mat_a_diag_inv[6]);
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 7th row elements of B
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i3 = 0;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
//Read next 8x8 block of A to get diag elements
|
|
i3 += 8;
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_l + i3);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l);
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[0]);
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[1]);
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[2]);
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[3]);
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[4]);
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)ptr_l + i3 + cs_l_offset[5]);
|
|
|
|
//pack 8 diags of A together
|
|
reciprocal_diags[0] = reciprocal_diags[1];
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xAA);//diag 0,1
|
|
mat_a_diag_inv[1] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xAA);//diag 2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_blk_elems[4], mat_a_blk_elems[5], 0xAA);//diag 4,5
|
|
mat_a_diag_inv[3] = _mm256_blend_ps(mat_a_blk_elems[6], mat_a_blk_elems[7], 0xAA);//diag 6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[1], 0xCC);//diag 0,1,2,3
|
|
mat_a_diag_inv[2] = _mm256_blend_ps(mat_a_diag_inv[2], mat_a_diag_inv[3], 0xCC);//diag 4,5,6,7
|
|
mat_a_diag_inv[0] = _mm256_blend_ps(mat_a_diag_inv[0], mat_a_diag_inv[2], 0xF0);//diag 0,1,2,3,4,5,6,7
|
|
|
|
//reciprocal of diagonal elements of A :- 0,1,2,3,4,5,6,7
|
|
reciprocal_diags[0] = _mm256_div_ps(reciprocal_diags[0], mat_a_diag_inv[0]);
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
//extract diag a00 from a
|
|
mat_a_diag_inv[0] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[0] = _mm256_permute2f128_ps(mat_a_diag_inv[0], mat_a_diag_inv[0], 0x00);
|
|
//mat_a_diag_inv2[0] = _mm256_unpacklo_ps(mat_a_diag_inv2[0], mat_a_diag_inv2[0]);
|
|
|
|
//extract diag a11 from a
|
|
mat_a_diag_inv[1] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[1] = _mm256_permute2f128_ps(mat_a_diag_inv[1], mat_a_diag_inv[1], 0x00);
|
|
//mat_a_diag_inv[1] = _mm256_unpacklo_ps(mat_a_diag_inv[1], mat_a_diag_inv[1]);
|
|
|
|
//extract diag a22 from a
|
|
mat_a_diag_inv[2] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[2] = _mm256_permute2f128_ps(mat_a_diag_inv[2], mat_a_diag_inv[2], 0x00);
|
|
//mat_a_diag_inv[2] = _mm256_unpacklo_ps(mat_a_diag_inv[2], mat_a_diag_inv[2]);
|
|
|
|
//extract diag a33 from a
|
|
mat_a_diag_inv[3] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[3] = _mm256_permute2f128_ps(mat_a_diag_inv[3], mat_a_diag_inv[3], 0x00);
|
|
//mat_a_diag_inv[3] = _mm256_unpacklo_ps(mat_a_diag_inv[3], mat_a_diag_inv[3]);
|
|
|
|
//extract diag a44 from a
|
|
mat_a_diag_inv[4] = _mm256_permute_ps(reciprocal_diags[0], 0x00);
|
|
mat_a_diag_inv[4] = _mm256_permute2f128_ps(mat_a_diag_inv[4], mat_a_diag_inv[4], 0x11);
|
|
//mat_a_diag_inv[4] = _mm256_unpacklo_ps(mat_a_diag_inv[4], mat_a_diag_inv[4]);
|
|
|
|
//extract diag a55 from a
|
|
mat_a_diag_inv[5] = _mm256_permute_ps(reciprocal_diags[0], 0x55);
|
|
mat_a_diag_inv[5] = _mm256_permute2f128_ps(mat_a_diag_inv[5], mat_a_diag_inv[5], 0x11);
|
|
//mat_a_diag_inv[5] = _mm256_unpacklo_ps(mat_a_diag_inv[5], mat_a_diag_inv[5]);
|
|
|
|
//extract diag a66 from a
|
|
mat_a_diag_inv[6] = _mm256_permute_ps(reciprocal_diags[0], 0xAA);
|
|
mat_a_diag_inv[6] = _mm256_permute2f128_ps(mat_a_diag_inv[6], mat_a_diag_inv[6], 0x11);
|
|
//mat_a_diag_inv[6] = _mm256_unpacklo_ps(mat_a_diag_inv[6], mat_a_diag_inv[6]);
|
|
|
|
//extract diag a77 from a
|
|
mat_a_diag_inv[7] = _mm256_permute_ps(reciprocal_diags[0], 0xFF);
|
|
mat_a_diag_inv[7] = _mm256_permute2f128_ps(mat_a_diag_inv[7], mat_a_diag_inv[7], 0x11);
|
|
//mat_a_diag_inv[7] = _mm256_unpacklo_ps(mat_a_diag_inv[7], mat_a_diag_inv[7]);
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
            //(Row8): FMA operations of b0 with broadcasted elements A8,0 up till A15,0
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
            //Broadcast A8,1 to A15,1 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row9): FMA operations of b1 with broadcasted elements A8,1 up till A15,1
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row10): FMA operations of b2 with broadcasted elements A8,2 up till A15,2
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row11): FMA operations of b3 with broadcasted elements A8,3 up till A15,3
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row12): FMA operations of b4 with broadcasted elements A8,4 up till A15,4
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row13): FMA operations of b5 with broadcasted elements A8,5 up till A15,5
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row14): FMA operations of b6 with broadcasted elements A8,6 up till A15,6
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
            //(Row15): FMA operations of b7 with broadcasted elements A8,7 up till A15,7
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): Perform mul operation of reciprocal of L(0,0) element with 1st row elements of B
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], mat_a_diag_inv[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(1,1) element with 2nd row elements of B
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], mat_a_diag_inv[1]);
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(2, 2) element with 3rd row elements of B
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], mat_a_diag_inv[2]);
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(3, 3) element with 4th row elements of B
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], mat_a_diag_inv[3]);
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(4, 4) element with 5th row elements of B
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], mat_a_diag_inv[4]);
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
//Perform mul operation of reciprocal of L(5, 5) element with 6th row elements of B
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], mat_a_diag_inv[5]);
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
//Perform mul operation of reciprocal of L(6, 6) element with 7th row elements of B
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], mat_a_diag_inv[6]);
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Perform mul operation of reciprocal of L(7, 7) element with 8th row elements of B
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], mat_a_diag_inv[7]);
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
|
|
//(Row0)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) up to (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) up to (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) up to (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) up to (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) up to (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) up to (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
|
|
static void trsm_AutXB_block_allSmallSizedMatrices_alpha_unitDiag(float *ptr_l, float *ptr_b, int numRows_lb, int numCols_b, int rs_l, int rs_b, int cs_l, int cs_b, float alpha)
|
|
{
|
|
//float ones = 1.0;
|
|
int i, i1, i2, i4, j, k, l, r;
|
|
int cs_b_offset[7];
|
|
int cs_l_offset[7];
|
|
float *ptr_b_dup, *ptr_l_dup;
|
|
|
|
//57 number of ymm(256 bits) registers used
|
|
__m256 mat_b_col[8];
|
|
__m256 mat_b_rearr[8];
|
|
__m256 mat_a_blk_elems[8];
|
|
//__m256 mat_a_diag_inv[8];
|
|
//__m256 reciprocal_diags[2];
|
|
__m256 alphaReg;
|
|
alphaReg = _mm256_broadcast_ss((float const *)&alpha);
|
|
|
|
// ---> considering that the matrix size is multiple of 16 rows and 8 cols <--- //
|
|
|
|
//L matrix offsets
|
|
cs_l_offset[0] = (cs_l << 1);
|
|
cs_l_offset[1] = cs_l + cs_l_offset[0];
|
|
cs_l_offset[2] = (cs_l << 2);
|
|
cs_l_offset[3] = cs_l + cs_l_offset[2];
|
|
cs_l_offset[4] = cs_l_offset[0] + cs_l_offset[2];
|
|
cs_l_offset[5] = cs_l + cs_l_offset[4];
|
|
cs_l_offset[6] = (cs_l_offset[5] + cs_l);
|
|
|
|
cs_b_offset[0] = (cs_b << 1);
|
|
cs_b_offset[1] = cs_b + cs_b_offset[0];
|
|
cs_b_offset[2] = (cs_b << 2);
|
|
cs_b_offset[3] = cs_b + cs_b_offset[2];
|
|
cs_b_offset[4] = cs_b_offset[0] + cs_b_offset[2];
|
|
cs_b_offset[5] = cs_b + cs_b_offset[4];
|
|
cs_b_offset[6] = (cs_b_offset[5] + cs_b);
|
|
|
|
#if 0
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 4));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 5));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 6));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + 7));
|
|
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 2));
|
|
mat_a_blk_elems[8] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 3));
|
|
mat_a_blk_elems[9] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 4));
|
|
mat_a_blk_elems[10] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 5));
|
|
mat_a_blk_elems[11] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 6));
|
|
mat_a_blk_elems[12] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l + 7));
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[13] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 3));
|
|
mat_a_blk_elems[14] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 4));
|
|
mat_a_blk_elems[15] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 5));
|
|
mat_a_blk_elems[16] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 6));
|
|
mat_a_blk_elems[17] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0] + 7));
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[18] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 4));
|
|
mat_a_blk_elems[19] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 5));
|
|
mat_a_blk_elems[20] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 6));
|
|
mat_a_blk_elems[21] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1] + 7));
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[22] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 5));
|
|
mat_a_blk_elems[23] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 6));
|
|
mat_a_blk_elems[24] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2] + 7));
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[25] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 6));
|
|
mat_a_blk_elems[26] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3] + 7));
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[27] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4] + 7));
|
|
#endif
|
|
|
|
|
|
/***************** first set of 8 rows of B processing starts *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i = 0;
|
|
for (j = 0; j < numCols_b; j += 8)
|
|
{
|
|
/////////////////// Complete Upper 8x8 block trsm of B :- upper 8x8 block of B with upper 8x8 block of A
|
|
//read 8x8 block of B into registers
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
|
|
//(Row0)
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l + cs_l_offset[5]));
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_col[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_col[1]);//d = c - (a*b)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l + 1 + cs_l_offset[5]));
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_col[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_col[2]);//d = c - (a*b)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l + 2 + cs_l_offset[5]));
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_col[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_col[3]);//d = c - (a*b)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l + 3 + cs_l_offset[5]));
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_col[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_col[4]);//d = c - (a*b)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l + 4 + cs_l_offset[5]));
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_col[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_col[5]);//d = c - (a*b)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l + 5 + cs_l_offset[5]));
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_col[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_col[6]);//d = c - (a*b)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l + 6 + cs_l_offset[5]));
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_col[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_col[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup, mat_b_rearr[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)), mat_b_rearr[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0]), mat_b_rearr[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1]), mat_b_rearr[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2]), mat_b_rearr[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3]), mat_b_rearr[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4]), mat_b_rearr[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5]), mat_b_rearr[7]);
|
|
|
|
i += cs_b_offset[6];
|
|
ptr_b_dup += cs_b_offset[6];
|
|
//i += 8;
|
|
//ptr_b_dup += 8;
|
|
}
|
|
|
|
//c = 0;
|
|
/***************** first set of 8 cols of B processing done *****************/
|
|
ptr_b_dup = ptr_b;
|
|
i1 = 0;
|
|
//Start loop for cols of B to be processed in size of blk_width
|
|
for (j = 8; j < numRows_lb; j += 8)//m :- 8x8 block row
|
|
{
|
|
ptr_l += cs_l_offset[6];
|
|
|
|
|
|
//ptr_b += j;
|
|
//ptr_b_dup += 8;
|
|
ptr_b_dup += 8;
|
|
i1 += 8;
|
|
i = i1;
|
|
i2 = 0;
|
|
|
|
for (r = 0; r < numCols_b; r += GEMM_BLK_V1)
|
|
{
|
|
#if GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_col[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_col[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_col[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_col[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_col[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_col[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_col[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_col[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_rearr[0] = _mm256_unpacklo_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_rearr[1] = _mm256_unpacklo_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_rearr[2] = _mm256_unpacklo_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_rearr[3] = _mm256_unpacklo_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_rearr[0] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_rearr[4] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_rearr[1] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_rearr[5] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_col[0] = _mm256_unpackhi_ps(mat_b_col[0], mat_b_col[1]);
|
|
mat_b_col[1] = _mm256_unpackhi_ps(mat_b_col[2], mat_b_col[3]);
|
|
mat_b_col[2] = _mm256_unpackhi_ps(mat_b_col[4], mat_b_col[5]);
|
|
mat_b_col[3] = _mm256_unpackhi_ps(mat_b_col[6], mat_b_col[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_rearr[2] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_rearr[6] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_rearr[3] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_rearr[7] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_b_rearr[0], alphaReg);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_b_rearr[1], alphaReg);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_b_rearr[2], alphaReg);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_b_rearr[3], alphaReg);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_b_rearr[4], alphaReg);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_b_rearr[5], alphaReg);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_b_rearr[6], alphaReg);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_b_rearr[7], alphaReg);
|
|
#endif
|
|
|
|
//i = 0;
|
|
ptr_l_dup = ptr_l;
|
|
i4 = i2;
|
|
for (l = 0; l < j; l += 8) // move across m
|
|
{
|
|
//for (k = 0; k < numCols_b; k += 8) // move across n for the same value of l (index of m)
|
|
//{
|
|
/////////////////// Partial Lower 8x8 block trsm of B
|
|
//Read current 8 cols of B columns from specified 8x8 current-block of B
|
|
mat_a_blk_elems[0] = _mm256_loadu_ps((float const *)ptr_b + i4);
|
|
mat_a_blk_elems[1] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b));
|
|
mat_a_blk_elems[2] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_loadu_ps((float const *)(ptr_b + i4 + cs_b_offset[5]));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_a_blk_elems[0] = _mm256_unpackhi_ps(mat_a_blk_elems[0], mat_a_blk_elems[1]);
|
|
mat_a_blk_elems[1] = _mm256_unpackhi_ps(mat_a_blk_elems[2], mat_a_blk_elems[3]);
|
|
mat_a_blk_elems[2] = _mm256_unpackhi_ps(mat_a_blk_elems[4], mat_a_blk_elems[5]);
|
|
mat_a_blk_elems[3] = _mm256_unpackhi_ps(mat_a_blk_elems[6], mat_a_blk_elems[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_a_blk_elems[4] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x44);
|
|
mat_a_blk_elems[5] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0xEE);
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x44);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0xEE);
|
|
#else
|
|
mat_a_blk_elems[6] = _mm256_shuffle_ps(mat_a_blk_elems[0], mat_a_blk_elems[1], 0x4E);
|
|
mat_a_blk_elems[7] = _mm256_shuffle_ps(mat_a_blk_elems[2], mat_a_blk_elems[3], 0x4E);
|
|
mat_a_blk_elems[4] = _mm256_blend_ps(mat_a_blk_elems[0], mat_a_blk_elems[6], 0xCC);
|
|
mat_a_blk_elems[5] = _mm256_blend_ps(mat_a_blk_elems[1], mat_a_blk_elems[6], 0x33);
|
|
mat_a_blk_elems[6] = _mm256_blend_ps(mat_a_blk_elems[2], mat_a_blk_elems[7], 0xCC);
|
|
mat_a_blk_elems[7] = _mm256_blend_ps(mat_a_blk_elems[3], mat_a_blk_elems[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_a_blk_elems[4], mat_a_blk_elems[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_a_blk_elems[5], mat_a_blk_elems[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Broadcast A8,0 to A15,0 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i4 = k >> 3;
|
|
ptr_l_dup++;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row8): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[0], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_mul_ps(mat_a_blk_elems[0], mat_b_col[0]);
|
|
mat_b_rearr[1] = _mm256_mul_ps(mat_a_blk_elems[1], mat_b_col[0]);
|
|
mat_b_rearr[2] = _mm256_mul_ps(mat_a_blk_elems[2], mat_b_col[0]);
|
|
mat_b_rearr[3] = _mm256_mul_ps(mat_a_blk_elems[3], mat_b_col[0]);
|
|
mat_b_rearr[4] = _mm256_mul_ps(mat_a_blk_elems[4], mat_b_col[0]);
|
|
mat_b_rearr[5] = _mm256_mul_ps(mat_a_blk_elems[5], mat_b_col[0]);
|
|
mat_b_rearr[6] = _mm256_mul_ps(mat_a_blk_elems[6], mat_b_col[0]);
|
|
mat_b_rearr[7] = _mm256_mul_ps(mat_a_blk_elems[7], mat_b_col[0]);
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row9): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[1], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[1], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,2 to A15,2 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row10): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[2], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[2], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[2], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,3 to A15,3 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row11): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[3], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[3], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[3], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[3], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,4 to A15,4 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row12): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[4], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[4], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[4], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[4], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[4], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,5 to A15,5 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row13): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[5], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[5], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[5], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[5], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[5], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[5], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,6 to A15,6 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row14): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[6], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[6], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[6], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[6], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[6], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[6], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[6], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A8,7 to A15,7 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[7] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
ptr_l_dup++;
|
|
#if GEMM_ACCUM_A
|
|
//(Row15): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[0] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[0] = _mm256_fmadd_ps(mat_a_blk_elems[0], mat_b_col[7], mat_b_rearr[0]);//d = c - (a*b)
|
|
mat_b_rearr[1] = _mm256_fmadd_ps(mat_a_blk_elems[1], mat_b_col[7], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fmadd_ps(mat_a_blk_elems[2], mat_b_col[7], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fmadd_ps(mat_a_blk_elems[3], mat_b_col[7], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fmadd_ps(mat_a_blk_elems[4], mat_b_col[7], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fmadd_ps(mat_a_blk_elems[5], mat_b_col[7], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fmadd_ps(mat_a_blk_elems[6], mat_b_col[7], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fmadd_ps(mat_a_blk_elems[7], mat_b_col[7], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//end loop of cols
|
|
//}
|
|
//i2 += cs_b_offset[6];
|
|
i4 += 8;
|
|
}
|
|
//trsm solve
|
|
|
|
k = 0;
|
|
//for (i2 = 0; i2 < numCols_b; i2 += 8)
|
|
//{
|
|
//i2 = i1 + r;
|
|
/////////////////// Complete Lower 8x8 block trsm of B :- lower 8x8 block of B with lower right 8x8 block of A
|
|
#if !GEMM_ACCUM_A
|
|
//Read 8 cols of B columns of Block-to-be-solved
|
|
mat_b_rearr[0] = _mm256_loadu_ps((float const *)ptr_b + i);
|
|
mat_b_rearr[1] = _mm256_loadu_ps((float const *)(ptr_b + cs_b + i));
|
|
mat_b_rearr[2] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[0] + i));
|
|
mat_b_rearr[3] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[1] + i));
|
|
mat_b_rearr[4] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[2] + i));
|
|
mat_b_rearr[5] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[3] + i));
|
|
mat_b_rearr[6] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[4] + i));
|
|
mat_b_rearr[7] = _mm256_loadu_ps((float const *)(ptr_b + cs_b_offset[5] + i));
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
mat_b_col[0] = _mm256_mul_ps(mat_b_col[0], alphaReg);
|
|
mat_b_col[1] = _mm256_mul_ps(mat_b_col[1], alphaReg);
|
|
mat_b_col[2] = _mm256_mul_ps(mat_b_col[2], alphaReg);
|
|
mat_b_col[3] = _mm256_mul_ps(mat_b_col[3], alphaReg);
|
|
mat_b_col[4] = _mm256_mul_ps(mat_b_col[4], alphaReg);
|
|
mat_b_col[5] = _mm256_mul_ps(mat_b_col[5], alphaReg);
|
|
mat_b_col[6] = _mm256_mul_ps(mat_b_col[6], alphaReg);
|
|
mat_b_col[7] = _mm256_mul_ps(mat_b_col[7], alphaReg);
|
|
#endif
|
|
//Broadcast A10 to A70 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[0]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[1]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[2]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[3]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[4]));
|
|
mat_a_blk_elems[6] = _mm256_broadcast_ss((float const *)(ptr_l_dup + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
#if GEMM_ACCUM_A
|
|
//(Row0): already done
|
|
|
|
#else
|
|
mat_b_rearr[0] = _mm256_sub_ps(mat_b_col[0], mat_b_rearr[0]);
|
|
#endif
|
|
|
|
#if GEMM_ACCUM_A
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#else
|
|
mat_b_rearr[1] = _mm256_sub_ps(mat_b_col[1], mat_b_rearr[1]);
|
|
mat_b_rearr[2] = _mm256_sub_ps(mat_b_col[2], mat_b_rearr[2]);
|
|
mat_b_rearr[3] = _mm256_sub_ps(mat_b_col[3], mat_b_rearr[3]);
|
|
mat_b_rearr[4] = _mm256_sub_ps(mat_b_col[4], mat_b_rearr[4]);
|
|
mat_b_rearr[5] = _mm256_sub_ps(mat_b_col[5], mat_b_rearr[5]);
|
|
mat_b_rearr[6] = _mm256_sub_ps(mat_b_col[6], mat_b_rearr[6]);
|
|
mat_b_rearr[7] = _mm256_sub_ps(mat_b_col[7], mat_b_rearr[7]);
|
|
|
|
//(Row1): FMA operations of b1 with elements of indices from (1, 0) uptill (7, 0)
|
|
mat_b_rearr[1] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[0], mat_b_rearr[1]);//d = c - (a*b)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[0], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[0], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[0], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[0], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[0], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[6], mat_b_rearr[0], mat_b_rearr[7]);//d = c - (a*b)
|
|
#endif
|
|
//Broadcast A21 to A71 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[0]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[1]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[2]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[3]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[4]));
|
|
mat_a_blk_elems[5] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 1 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row2): FMA operations of b2 with elements of indices from (2, 0) uptill (7, 0)
|
|
mat_b_rearr[2] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[1], mat_b_rearr[2]);//d = c - (a*b)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[1], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[1], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[1], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[1], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[5], mat_b_rearr[1], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A32 to A72 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[1]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[2]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[3]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[4]));
|
|
mat_a_blk_elems[4] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 2 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row3): FMA operations of b3 with elements of indices from (3, 0) uptill (7, 0)
|
|
mat_b_rearr[3] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[2], mat_b_rearr[3]);//d = c - (a*b)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[2], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[2], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[2], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[4], mat_b_rearr[2], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A43 to A73 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[2]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[3]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[4]));
|
|
mat_a_blk_elems[3] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 3 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row4): FMA operations of b4 with elements of indices from (4, 0) uptill (7, 0)
|
|
mat_b_rearr[4] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[3], mat_b_rearr[4]);//d = c - (a*b)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[3], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[3], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[3], mat_b_rearr[3], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A54 to A74 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[3]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[4]));
|
|
mat_a_blk_elems[2] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 4 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row5): FMA operations of b5 with elements of indices from (5, 0) uptill (7, 0)
|
|
mat_b_rearr[5] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[4], mat_b_rearr[5]);//d = c - (a*b)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[4], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[2], mat_b_rearr[4], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A65 to A75 to registers
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[4]));
|
|
mat_a_blk_elems[1] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 5 + cs_l_offset[5]));
|
|
//i += cs_l;
|
|
|
|
|
|
|
|
//(Row6): FMA operations of b6 with elements of indices from (6, 0) uptill (7, 0)
|
|
mat_b_rearr[6] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[5], mat_b_rearr[6]);//d = c - (a*b)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[1], mat_b_rearr[5], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
//Broadcast A76 to register
|
|
mat_a_blk_elems[0] = _mm256_broadcast_ss((float const *)(ptr_l_dup + 6 + cs_l_offset[5]));
|
|
|
|
|
|
|
|
//(Row7): FMA operations of b7 with elements of index (7, 0)
|
|
mat_b_rearr[7] = _mm256_fnmadd_ps(mat_a_blk_elems[0], mat_b_rearr[6], mat_b_rearr[7]);//d = c - (a*b)
|
|
|
|
|
|
|
|
////////////////////////////////////////////////////////////////////////////////
|
|
|
|
/* transpose steps start */
|
|
////unpacklow////
|
|
mat_b_col[0] = _mm256_unpacklo_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_col[1] = _mm256_unpacklo_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_col[2] = _mm256_unpacklo_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_col[3] = _mm256_unpacklo_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange low elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_col[4] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x44);
|
|
mat_b_col[5] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0xEE);
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x44);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0xEE);
|
|
#else
|
|
mat_b_col[6] = _mm256_shuffle_ps(mat_b_col[0], mat_b_col[1], 0x4E);
|
|
mat_b_col[7] = _mm256_shuffle_ps(mat_b_col[2], mat_b_col[3], 0x4E);
|
|
mat_b_col[4] = _mm256_blend_ps(mat_b_col[0], mat_b_col[6], 0xCC);
|
|
mat_b_col[5] = _mm256_blend_ps(mat_b_col[1], mat_b_col[6], 0x33);
|
|
mat_b_col[6] = _mm256_blend_ps(mat_b_col[2], mat_b_col[7], 0xCC);
|
|
mat_b_col[7] = _mm256_blend_ps(mat_b_col[3], mat_b_col[7], 0x33);
|
|
#endif
|
|
//Merge rearranged low elements into complete rows
|
|
mat_b_col[0] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x20);
|
|
mat_b_col[4] = _mm256_permute2f128_ps(mat_b_col[4], mat_b_col[6], 0x31);
|
|
mat_b_col[1] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x20);
|
|
mat_b_col[5] = _mm256_permute2f128_ps(mat_b_col[5], mat_b_col[7], 0x31);
|
|
|
|
////unpackhigh////
|
|
mat_b_rearr[0] = _mm256_unpackhi_ps(mat_b_rearr[0], mat_b_rearr[1]);
|
|
mat_b_rearr[1] = _mm256_unpackhi_ps(mat_b_rearr[2], mat_b_rearr[3]);
|
|
mat_b_rearr[2] = _mm256_unpackhi_ps(mat_b_rearr[4], mat_b_rearr[5]);
|
|
mat_b_rearr[3] = _mm256_unpackhi_ps(mat_b_rearr[6], mat_b_rearr[7]);
|
|
|
|
//Rearrange high elements
|
|
#if REARRANGE_SHFL == 1
|
|
mat_b_rearr[4] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x44);
|
|
mat_b_rearr[5] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0xEE);
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x44);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0xEE);
|
|
#else
|
|
mat_b_rearr[6] = _mm256_shuffle_ps(mat_b_rearr[0], mat_b_rearr[1], 0x4E);
|
|
mat_b_rearr[7] = _mm256_shuffle_ps(mat_b_rearr[2], mat_b_rearr[3], 0x4E);
|
|
mat_b_rearr[4] = _mm256_blend_ps(mat_b_rearr[0], mat_b_rearr[6], 0xCC);
|
|
mat_b_rearr[5] = _mm256_blend_ps(mat_b_rearr[1], mat_b_rearr[6], 0x33);
|
|
mat_b_rearr[6] = _mm256_blend_ps(mat_b_rearr[2], mat_b_rearr[7], 0xCC);
|
|
mat_b_rearr[7] = _mm256_blend_ps(mat_b_rearr[3], mat_b_rearr[7], 0x33);
|
|
#endif
|
|
|
|
//Merge rearranged high elements into complete rows
|
|
mat_b_col[2] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x20);
|
|
mat_b_col[6] = _mm256_permute2f128_ps(mat_b_rearr[4], mat_b_rearr[6], 0x31);
|
|
mat_b_col[3] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x20);
|
|
mat_b_col[7] = _mm256_permute2f128_ps(mat_b_rearr[5], mat_b_rearr[7], 0x31);
|
|
/* transpose steps end */
|
|
|
|
//Store the computed B columns
|
|
_mm256_storeu_ps((float *)ptr_b_dup + i2, mat_b_col[0]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + (cs_b)+i2), mat_b_col[1]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[0] + i2), mat_b_col[2]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[1] + i2), mat_b_col[3]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[2] + i2), mat_b_col[4]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[3] + i2), mat_b_col[5]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[4] + i2), mat_b_col[6]);
|
|
_mm256_storeu_ps((float *)(ptr_b_dup + cs_b_offset[5] + i2), mat_b_col[7]);
|
|
//printf("writing B => m[%d], n[%d], [%f]\n", j, k, *(ptr_b_dup + k));
|
|
k++;
|
|
//}
|
|
i += cs_b_offset[6];
|
|
i2 += cs_b_offset[6];
|
|
}
|
|
} //numRows of A
|
|
///////////////////loop ends /////////////////////
|
|
}
|
|
#endif
|