From 8c537b0cd564c099f00d8a6355d584736beef60f Mon Sep 17 00:00:00 2001 From: Aayush Kumar Date: Thu, 23 Mar 2023 06:14:39 +0000 Subject: [PATCH] Added DTRSM Small Path AVX512 based LLNN/LUTN Variant Kernels - 8x8 kernels are used for DTRSM SMALL - Implemented fringe cases with below block sizes 8x8, 8x4, 8x3, 8x2, 8x1 4x8, 4x4, 4x3, 4x2, 4x1 3x8, 3x4, 3x3, 3x2, 3x1 2x8, 2x4, 2x3, 2x2, 2x1 1x8, 1x4, 1x3, 1x2, 1x1 AMD-Internal: [CPUPL-2745] Change-Id: I58d28912bddbaadb404052c0f3449ebbe3c97b68 --- frame/compat/bla_trsm_amd.c | 11 +- frame/include/bli_trsm_small_ref.h | 129 ++ kernels/zen/3/bli_trsm_small.c | 142 +- kernels/zen4/3/bli_trsm_small_AVX512.c | 2686 +++++++++++++++++++++++- 4 files changed, 2811 insertions(+), 157 deletions(-) create mode 100644 frame/include/bli_trsm_small_ref.h diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 2b0c4bc5e..f58aa3710 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -1059,10 +1059,9 @@ void dtrsm_blis_impl { case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) - // check if variant is RUN[N/U] or RLT[N/U] // this is a temporary fix, will be removed when all variants are added - if( (blis_side == BLIS_RIGHT) && - ((n0 > 300) && (m0 > 50))) + if( ((blis_side == BLIS_RIGHT) && ((n0 > 300) && (m0 > 50))) || + ((blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_NO_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_TRANSPOSE) ) ) && ((n0 != 30 && n0 !=60 ) && (m0 > 50))) ) { ker_ft = bli_trsm_small_AVX512; } @@ -1089,13 +1088,13 @@ void dtrsm_blis_impl { case BLIS_ARCH_ZEN4: #if defined(BLIS_KERNELS_ZEN4) - if( blis_side == BLIS_RIGHT ) + if ( (blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_NO_TRANSPOSE) ) )) { - ker_ft = bli_trsm_small_mt_AVX512; + ker_ft = bli_trsm_small_mt; } else { - ker_ft = bli_trsm_small_mt; + ker_ft = bli_trsm_small_mt_AVX512; } break; #endif// BLIS_KERNELS_ZEN4 diff --git a/frame/include/bli_trsm_small_ref.h b/frame/include/bli_trsm_small_ref.h new file mode 100644 index 000000000..715db884e --- /dev/null +++ b/frame/include/bli_trsm_small_ref.h @@ -0,0 +1,129 @@ +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + #define DIAG_ELE_INV_OPS(a, b) (a / b) + #define DIAG_ELE_EVAL_OPS(a, b) (a * b) +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + #define DIAG_ELE_INV_OPS(a, b) (a * b) + #define DIAG_ELE_EVAL_OPS(a, b) (a / b) +#endif + +// reference code for LUTN +BLIS_INLINE err_t dtrsm_AutXB_ref + ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool unitDiagonal + ) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + double lkk_inv = 1.0; + if (!unitDiagonal) + lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]); + for (j = 0; j < N; j++) + { + B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv); + for (i = k + 1; i < M; i++) + { + B[i + j * ldb] -= A[i * lda + k] * B[k + j * ldb]; + } + } + } // k -loop + return BLIS_SUCCESS; +} + +// reference code for LLNN +BLIS_INLINE err_t dtrsm_AlXB_ref + ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag + ) +{ + dim_t i, j, k; + for (k = 0; k < M; k++) + { + double lkk_inv = 1.0; + if (!is_unitdiag) + lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]); + for (j = 0; j < N; j++) + { + B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv); + for (i = k + 1; i < M; i++) + { + B[i + j * ldb] -= A[i + k * lda] * B[k + j * ldb]; + } + } 
+ } // k -loop + return BLIS_SUCCESS; +} + +// reference code for LUNN +BLIS_INLINE err_t dtrsm_AuXB_ref + ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag + ) +{ + dim_t i, j, k; + for (k = M - 1; k >= 0; k--) + { + double lkk_inv = 1.0; + if (!is_unitdiag) + lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]); + for (j = N - 1; j >= 0; j--) + { + B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv); + for (i = k - 1; i >= 0; i--) + { + B[i + j * ldb] -= A[i + k * lda] * B[k + j * ldb]; + } + } + } // k -loop + return BLIS_SUCCESS; +} // end of function + +// reference code for LLTN +BLIS_INLINE err_t dtrsm_AltXB_ref + ( + double *A, + double *B, + dim_t M, + dim_t N, + dim_t lda, + dim_t ldb, + bool is_unitdiag + ) +{ + dim_t i, j, k; + for (k = M - 1; k >= 0; k--) + { + double lkk_inv = 1.0; + if (!is_unitdiag) + lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]); + for (j = N - 1; j >= 0; j--) + { + B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv); + for (i = k - 1; i >= 0; i--) + { + B[i + j * ldb] -= A[i * lda + k] * B[k + j * ldb]; + } + } + } // k -loop + return BLIS_SUCCESS; +} // end of function diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 8cb7bb786..15dec8733 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -35,6 +35,7 @@ #include "blis.h" #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM #include "immintrin.h" +#include "bli_trsm_small_ref.h" #define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL @@ -107,18 +108,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB cntl_t* cntl ); -//AX = B; A is lower triangular; transpose; -//double precision; non-unit diagonal -BLIS_INLINE err_t dtrsm_AltXB_ref -( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag -); /* * ZTRSM kernel declaration */ @@ -248,41 +237,6 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB #define DIAG_ELE_EVAL_OPS(a,b) (a / b) #endif -/* - * Reference implementations - * ToDo: We can combine all these reference implementation - into a macro -*/ -//A'X = B; A is upper triangular; transpose; -//non-unitDiagonal double precision -BLIS_INLINE err_t dtrsm_AutXB_ref -( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool unitDiagonal -) -{ - dim_t i, j, k; - for (k = 0; k < M; k++) - { - double lkk_inv = 1.0; - if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = 0; j < N; j++) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv); - for (i = k+1; i < M; i++) - { - B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; - } - } - }// k -loop - return BLIS_SUCCESS; -}// end of function - /* * Reference implementations * ToDo: We can combine all these reference implementation @@ -318,37 +272,6 @@ BLIS_INLINE err_t strsm_AutXB_ref return BLIS_SUCCESS; }// end of function -/* TRSM scalar code for the case AX = alpha * B - * A is upper-triangular, non-unit-diagonal - * Dimensions: A: mxm X: mxn B:mxn - */ -BLIS_INLINE err_t dtrsm_AuXB_ref -( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag -) -{ - dim_t i, j, k; - for (k = M-1; k >= 0; k--) - { - double lkk_inv = 1.0; - if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = N -1; j >= 0; j--) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); - for (i = k-1; i >=0; i--) - { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; - } - } - }// k -loop - return BLIS_SUCCESS; -}// end of function /* TRSM scalar code 
for the case AX = alpha * B * A is upper-triangular, non-unit-diagonal @@ -382,37 +305,6 @@ BLIS_INLINE err_t strsm_AuXB_ref return BLIS_SUCCESS; }// end of function -/* TRSM scalar code for the case AX = alpha * B - * A is lower-triangular, non-unit-diagonal, no transpose - * Dimensions: A: mxm X: mxn B:mxn - */ -BLIS_INLINE err_t dtrsm_AlXB_ref -( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag -) -{ - dim_t i, j, k; - for (k = 0; k < M; k++) - { - double lkk_inv = 1.0; - if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = 0; j < N; j++) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); - for (i = k+1; i < M; i++) - { - B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb]; - } - } - }// k -loop - return BLIS_SUCCESS; -}// end of function /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, no transpose @@ -446,38 +338,6 @@ BLIS_INLINE err_t strsm_AlXB_ref return BLIS_SUCCESS; }// end of function -/* TRSM scalar code for the case AX = alpha * B - * A is lower-triangular, non-unit-diagonal, transpose - * Dimensions: A: mxm X: mxn B:mxn - */ -BLIS_INLINE err_t dtrsm_AltXB_ref -( - double *A, - double *B, - dim_t M, - dim_t N, - dim_t lda, - dim_t ldb, - bool is_unitdiag -) -{ - dim_t i, j, k; - for (k = M-1; k >= 0; k--) - { - double lkk_inv = 1.0; - if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]); - for (j = N -1; j >= 0; j--) - { - B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv); - for (i = k-1; i >=0; i--) - { - B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb]; - } - } - }// k -loop - return BLIS_SUCCESS; -}// end of function - /* TRSM scalar code for the case AX = alpha * B * A is lower-triangular, non-unit-diagonal, transpose * Dimensions: A: mxm X: mxn B:mxn diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c index b9c4bd3b4..639ada81e 100644 --- a/kernels/zen4/3/bli_trsm_small_AVX512.c +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -27,20 +27,13 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ #include "blis.h" +#include "bli_trsm_small_ref.h" #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM #include "immintrin.h" #define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL -#ifdef BLIS_ENABLE_TRSM_PREINVERSION - #define DIAG_ELE_INV_OPS(a, b) (a / b) - #define DIAG_ELE_EVAL_OPS(a, b) (a * b) -#endif -#ifdef BLIS_DISABLE_TRSM_PREINVERSION - #define DIAG_ELE_INV_OPS(a, b) (a * b) - #define DIAG_ELE_EVAL_OPS(a, b) (a / b) -#endif #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define DTRSM_SMALL_DIV_OR_SCALE _mm256_div_pd @@ -86,6 +79,17 @@ ymm30 = _mm256_setzero_pd(); \ ymm31 = _mm256_setzero_pd(); +#define BLIS_SET_YMM_REG_ZEROS_FOR_LEFT \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm11 = _mm256_setzero_pd(); \ + ymm12 = _mm256_setzero_pd(); \ + ymm13 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + ymm16 = _mm256_setzero_pd(); \ + #define BLIS_SET_ZMM_REG_ZEROS \ zmm0 = _mm512_setzero_pd(); \ zmm1 = _mm512_setzero_pd(); \ @@ -143,6 +147,7 @@ typedef err_t (*trsmsmall_ker_ft) cntl_t* cntl ); + /* Pack a block of 8xk from input buffer into packed buffer directly or after transpose based on input params @@ -163,7 +168,104 @@ BLIS_INLINE void bli_dtrsm_small_pack_avx512 __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; if (side == 'L' || side == 'l') { - return; // BLIS_NOT_YET_IMPLEMENTED + /*Left case is 8xk*/ + if (trans) + { + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13; + for (dim_t x = 0; x < size; x += mr) + { + ymm0 = _mm256_loadu_pd((double const *)(inbuf)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2)); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3)); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(pbuff), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda * 3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9); + + ymm0 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4)); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 + 4)); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4)); + ymm2 = 
_mm256_loadu_pd((double const *)(inbuf + cs_a * 6)); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6 + 4)); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7)); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7 + 4)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(pbuff + 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 3), ymm9); + + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9); + + inbuf += mr; + pbuff += mr * mr; + } + } + else + for (dim_t x = 0; x < size; x++) + { + zmm0 = _mm512_loadu_pd((double const *)(inbuf)); + _mm512_storeu_pd((double *)(pbuff), zmm0); + inbuf += cs_a; + pbuff += p_lda; + } } else if (side == 'R' || side == 'r') { @@ -623,6 +725,7 @@ err_t bli_trsm_small_mt_AVX512 } // End of function #endif + // region - GEMM DTRSM for right variants #define BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ @@ -6498,6 +6601,622 @@ else if ( n_remainder == 2) return BLIS_SUCCESS; } +// region - 8x8 transpose for left variants +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8(b11, cs_b, AlphaVal) \ + zmm8 = _mm512_set1_pd(AlphaVal); \ + zmm0 = _mm512_loadu_pd((double const *)b11 + (cs_b * 0)); \ + zmm1 = _mm512_loadu_pd((double const *)b11 + (cs_b * 1)); \ + zmm2 = _mm512_loadu_pd((double const *)b11 + (cs_b * 2)); \ + zmm3 = _mm512_loadu_pd((double const *)b11 + (cs_b * 3)); \ + zmm0 = _mm512_fmsub_pd(zmm0, zmm8, zmm9); \ + zmm1 = _mm512_fmsub_pd(zmm1, zmm8, zmm10); \ + zmm2 = _mm512_fmsub_pd(zmm2, zmm8, zmm11); \ + zmm3 = _mm512_fmsub_pd(zmm3, zmm8, zmm12); \ + \ + zmm4 = _mm512_loadu_pd((double const *)b11 + (cs_b * 4)); \ + zmm5 = _mm512_loadu_pd((double const *)b11 + (cs_b * 5)); \ + zmm6 = _mm512_loadu_pd((double const *)b11 + (cs_b * 6)); \ + zmm7 = _mm512_loadu_pd((double const *)b11 + (cs_b * 7)); \ + zmm4 = _mm512_fmsub_pd(zmm4, zmm8, zmm13); \ + zmm5 = _mm512_fmsub_pd(zmm5, zmm8, zmm14); \ + zmm6 = _mm512_fmsub_pd(zmm6, zmm8, zmm15); \ + zmm7 = _mm512_fmsub_pd(zmm7, zmm8, zmm16); \ + /*Stage1*/ \ + zmm17 = _mm512_unpacklo_pd(zmm0, zmm1); \ + zmm18 = _mm512_unpacklo_pd(zmm2, zmm3); \ + zmm19 = _mm512_unpacklo_pd(zmm4, zmm5); \ + zmm20 = _mm512_unpacklo_pd(zmm6, zmm7); \ + /*Stage2*/ \ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \ + /*Stage3 1,5*/ \ + zmm9 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm13 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \ + /*Stage2*/ \ + zmm21 = 
_mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \ + /*Stage3 3,7*/ \ + zmm11 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm15 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \ + /*Stage1*/ \ + zmm17 = _mm512_unpackhi_pd(zmm0, zmm1); \ + zmm18 = _mm512_unpackhi_pd(zmm2, zmm3); \ + zmm19 = _mm512_unpackhi_pd(zmm4, zmm5); \ + zmm20 = _mm512_unpackhi_pd(zmm6, zmm7); \ + /*Stage2*/ \ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \ + /*Stage3 2,6*/ \ + zmm10 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm14 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \ + /*Stage2*/ \ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \ + /*Stage3 4,8*/ \ + zmm12 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm16 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); + +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8(b11, cs_b, AlphaVal) \ + ymm8 = _mm256_broadcast_sd((double const *)(&AlphaVal));\ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11));\ + ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b *1))); \ + ymm2 = _mm256_loadu_pd((double const *)(b11 + (cs_b *2))); \ + ymm3 = _mm256_loadu_pd((double const *)(b11 + (cs_b *3))); \ + ymm0 = _mm256_fmsub_pd(ymm0, ymm8, ymm9); \ + ymm1 = _mm256_fmsub_pd(ymm1, ymm8, ymm10); \ + ymm2 = _mm256_fmsub_pd(ymm2, ymm8, ymm11); \ + ymm3 = _mm256_fmsub_pd(ymm3, ymm8, ymm12); \ + \ + ymm10 = _mm256_unpacklo_pd(ymm0, ymm1); \ + ymm12 = _mm256_unpacklo_pd(ymm2, ymm3); \ + ymm9 = _mm256_permute2f128_pd(ymm10,ymm12,0x20); \ + ymm11 = _mm256_permute2f128_pd(ymm10,ymm12,0x31); \ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); \ + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); \ + ymm10 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ + ymm12 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); \ + \ + ymm8 = _mm256_broadcast_sd((double const *)(&AlphaVal)); \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b *4))); \ + ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b *5))); \ + ymm2 = _mm256_loadu_pd((double const *)(b11 + (cs_b *6))); \ + ymm3 = _mm256_loadu_pd((double const *)(b11 + (cs_b *7))); \ + ymm0 = _mm256_fmsub_pd(ymm0, ymm8, ymm13); \ + ymm1 = _mm256_fmsub_pd(ymm1, ymm8, ymm14); \ + ymm2 = _mm256_fmsub_pd(ymm2, ymm8, ymm15); \ + ymm3 = _mm256_fmsub_pd(ymm3, ymm8, ymm16); \ + \ + ymm14 = _mm256_unpacklo_pd(ymm0, ymm1); \ + ymm16 = _mm256_unpacklo_pd(ymm2, ymm3); \ + ymm13 = _mm256_permute2f128_pd(ymm14,ymm16,0x20); \ + ymm15 = _mm256_permute2f128_pd(ymm14,ymm16,0x31); \ + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); \ + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); \ + ymm14 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); \ + ymm16 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); \ + +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8_AND_STORE(b11, cs_b) \ + ymm1 = _mm256_unpacklo_pd(ymm9, ymm10); \ + ymm3 = _mm256_unpacklo_pd(ymm11, ymm12); \ + \ + /*rearrange low elements*/\ + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); \ + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); \ + \ + /*unpack high*/\ + ymm9 = _mm256_unpackhi_pd(ymm9, ymm10); \ + ymm10 = _mm256_unpackhi_pd(ymm11, ymm12); \ + \ + /*rearrange high elements*/\ + ymm1 = _mm256_permute2f128_pd(ymm9, ymm10, 0x20); \ + ymm3 = _mm256_permute2f128_pd(ymm9, ymm10, 0x31); \ + \ + /*unpacklow*/\ + ymm5 = _mm256_unpacklo_pd(ymm13, ymm14); \ + ymm7 = _mm256_unpacklo_pd(ymm15, ymm16); \ + \ + /*rearrange low elements*/\ + ymm4 = 
_mm256_permute2f128_pd(ymm5, ymm7, 0x20); \ + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); \ + \ + /*unpack high*/\ + ymm13 = _mm256_unpackhi_pd(ymm13, ymm14); \ + ymm14 = _mm256_unpackhi_pd(ymm15, ymm16); \ + \ + /*rearrange high elements*/\ + ymm5 = _mm256_permute2f128_pd(ymm13, ymm14, 0x20); \ + ymm7 = _mm256_permute2f128_pd(ymm13, ymm14, 0x31); \ + +/* +zmm9 [0] zmm9 [1] zmm9 [2] zmm9 [3] zmm9 [4] zmm9 [5] zmm9 [6] zmm9 [7] +zmm10[0] zmm10[1] zmm10[2] zmm10[3] zmm10[4] zmm10[5] zmm10[6] zmm10[7] +zmm11[0] zmm11[1] zmm11[2] zmm11[3] zmm11[4] zmm11[5] zmm11[6] zmm11[7] +zmm12[0] zmm12[1] zmm12[2] zmm12[3] zmm12[4] zmm12[5] zmm12[6] zmm12[7] +zmm13[0] zmm13[1] zmm13[2] zmm13[3] zmm13[4] zmm13[5] zmm13[6] zmm13[7] +zmm14[0] zmm14[1] zmm14[2] zmm14[3] zmm14[4] zmm14[5] zmm14[6] zmm14[7] +zmm15[0] zmm15[1] zmm15[2] zmm15[3] zmm15[4] zmm15[5] zmm15[6] zmm15[7] +zmm16[0] zmm16[1] zmm16[2] zmm16[3] zmm16[4] zmm16[5] zmm16[6] zmm16[7] + +Stage1 +zmm17 = zmm10[1] zmm9 [1] zmm10[3] zmm9 [3] zmm10[5] zmm9 [5] zmm10[7] zmm9 [7] +zmm18 = zmm12[1] zmm11[1] zmm12[3] zmm11[3] zmm12[5] zmm11[5] zmm12[7] zmm11[7] +zmm19 = zmm14[1] zmm13[1] zmm14[3] zmm13[3] zmm14[5] zmm13[5] zmm14[7] zmm13[7] +zmm20 = zmm16[1] zmm15[1] zmm16[3] zmm15[3] zmm16[5] zmm15[5] zmm16[7] zmm15[7] + +Stage2 +zmm21 = zmm12[3] zmm11[3] zmm12[7] zmm11[7] zmm10[3] zmm9 [3] zmm10[7] zmm9 [7] +zmm22 = zmm16[3] zmm15[3] zmm16[7] zmm15[7] zmm14[3] zmm13[3] zmm14[7] zmm13[7] + +Stage3 1,5 +zmm0 = zmm16[7] zmm15[7] zmm14[7] zmm13[7] zmm12[7] zmm11[7] zmm10[7] zmm9 [7] +zmm4 = zmm16[3] zmm15[3] zmm14[3] zmm13[3] zmm12[3] zmm11[3] zmm10[3] zmm9 [3] + +Stage2 +zmm21 = zmm12[1] zmm11[1] zmm12[5] zmm11[5] zmm10[1] zmm9 [1] zmm10[5] zmm9 [5] +zmm22 = zmm16[1] zmm15[1] zmm16[5] zmm15[5] zmm14[1] zmm13[1] zmm14[5] zmm13[5] + +Stage3 3,7 +zmm2 = zmm16[5] zmm15[5] zmm14[5] zmm13[5] zmm12[5] zmm11[5] zmm10[5] zmm9 [5] +zmm6 = zmm16[1] zmm15[1] zmm14[1] zmm13[1] zmm12[1] zmm11[1] zmm10[1] zmm9 [1] + +Stage1 +zmm17 = zmm10[0] zmm9 [0] zmm10[2] zmm9 [2] zmm10[4] zmm9 [4] zmm10[6] zmm9 [6] +zmm18 = zmm12[0] zmm11[0] zmm12[2] zmm11[2] zmm12[4] zmm11[4] zmm12[6] zmm11[6] +zmm19 = zmm14[0] zmm13[0] zmm14[2] zmm13[2] zmm14[4] zmm13[4] zmm14[6] zmm13[6] +zmm20 = zmm16[0] zmm15[0] zmm16[2] zmm15[2] zmm16[4] zmm15[4] zmm16[6] zmm15[6] + +Stage2 +zmm21 = zmm12[2] zmm11[2] zmm12[6] zmm11[6] zmm10[2] zmm9 [2] zmm10[6] zmm9 [6] +zmm22 = zmm16[2] zmm15[2] zmm16[6] zmm15[6] zmm14[2] zmm13[2] zmm14[6] zmm13[6] + +Stage3 2,6 +zmm1 = zmm16[6] zmm15[6] zmm14[6] zmm13[6] zmm12[6] zmm11[6] zmm10[6] zmm9 [6] +zmm5 = zmm16[2] zmm15[2] zmm14[2] zmm13[2] zmm12[2] zmm11[2] zmm10[2] zmm9 [2] + +Stage2 +zmm21 = zmm12[0] zmm11[0] zmm12[4] zmm11[4] zmm10[0] zmm9 [0] zmm10[4] zmm9 [4] +zmm22 = zmm16[0] zmm15[0] zmm16[4] zmm15[4] zmm14[0] zmm13[0] zmm14[4] zmm13[4] + +Stage3 4,8 +zmm3 = zmm16[4] zmm15[4] zmm14[4] zmm13[4] zmm12[4] zmm11[4] zmm10[4] zmm9 [4] +zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0] +*/ +#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8_AND_STORE(b11, cs_b) \ + /*Stage1*/ \ + zmm17 = _mm512_unpacklo_pd(zmm9, zmm10); \ + zmm18 = _mm512_unpacklo_pd(zmm11, zmm12); \ + zmm19 = _mm512_unpacklo_pd(zmm13, zmm14); \ + zmm20 = _mm512_unpacklo_pd(zmm15, zmm16); \ + /*Stage2*/ \ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \ + /*Stage3 1,5*/ \ + zmm0 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm4 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \ + /*Stage2*/ 
\ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \ + /*Stage3 3,7*/ \ + zmm2 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm6 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \ + /*Stage1*/ \ + zmm17 = _mm512_unpackhi_pd(zmm9, zmm10); \ + zmm18 = _mm512_unpackhi_pd(zmm11, zmm12); \ + zmm19 = _mm512_unpackhi_pd(zmm13, zmm14); \ + zmm20 = _mm512_unpackhi_pd(zmm15, zmm16); \ + /*Stage2*/ \ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \ + /*Stage3 2,6*/ \ + zmm1 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm5 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \ + /*Stage2*/ \ + zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \ + zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \ + /*Stage3 4,8*/ \ + zmm3 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \ + zmm7 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); + +// endregion - 8x8 transpose for left variants + +// region - GEMM DTRSM for left variants + +#define BLIS_DTRSM_SMALL_GEMM_8mx8n_AVX512(a10, b01, cs_b, p_lda, k_iter, b11) \ + /*k_iter -= 8; */ \ + int itrCount = (k_iter / 2); \ + int itr = itrCount; \ + int itr2 = k_iter - itrCount; \ + double *b01_2 = b01 + itr; \ + double *a10_2 = a10 + (p_lda * itr); \ + for (; itr > 0; itr--) \ + { \ + zmm0 = _mm512_loadu_pd((double const *)a10); \ + \ + zmm1 = _mm512_set1_pd(*(b01 + cs_b * 0)); \ + zmm2 = _mm512_set1_pd(*(b01 + cs_b * 1)); \ + zmm3 = _mm512_set1_pd(*(b01 + cs_b * 2)); \ + zmm4 = _mm512_set1_pd(*(b01 + cs_b * 3)); \ + zmm5 = _mm512_set1_pd(*(b01 + cs_b * 4)); \ + zmm6 = _mm512_set1_pd(*(b01 + cs_b * 5)); \ + zmm7 = _mm512_set1_pd(*(b01 + cs_b * 6)); \ + zmm8 = _mm512_set1_pd(*(b01 + cs_b * 7)); \ + \ + _mm_prefetch((b01 + 8), _MM_HINT_T0); \ + zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9); \ + zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ + zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ + zmm12 = _mm512_fmadd_pd(zmm4, zmm0, zmm12); \ + zmm13 = _mm512_fmadd_pd(zmm5, zmm0, zmm13); \ + zmm14 = _mm512_fmadd_pd(zmm6, zmm0, zmm14); \ + zmm15 = _mm512_fmadd_pd(zmm7, zmm0, zmm15); \ + zmm16 = _mm512_fmadd_pd(zmm8, zmm0, zmm16); \ + \ + b01 += 1; \ + a10 += p_lda; \ + } \ + for (; itr2 > 0; itr2--) \ + { \ + zmm23 = _mm512_loadu_pd((double const *)a10_2); \ + \ + zmm17 = _mm512_set1_pd(*(b01_2 + cs_b * 0)); \ + zmm18 = _mm512_set1_pd(*(b01_2 + cs_b * 1)); \ + zmm19 = _mm512_set1_pd(*(b01_2 + cs_b * 2)); \ + zmm20 = _mm512_set1_pd(*(b01_2 + cs_b * 3)); \ + zmm21 = _mm512_set1_pd(*(b01_2 + cs_b * 4)); \ + zmm22 = _mm512_set1_pd(*(b01_2 + cs_b * 5)); \ + \ + _mm_prefetch((b01_2 + 8), _MM_HINT_T0); \ + zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \ + zmm17 = _mm512_set1_pd(*(b01_2 + cs_b * 6)); \ + zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \ + zmm18 = _mm512_set1_pd(*(b01_2 + cs_b * 7)); \ + zmm26 = _mm512_fmadd_pd(zmm19, zmm23, zmm26); \ + zmm27 = _mm512_fmadd_pd(zmm20, zmm23, zmm27); \ + zmm28 = _mm512_fmadd_pd(zmm21, zmm23, zmm28); \ + zmm29 = _mm512_fmadd_pd(zmm22, zmm23, zmm29); \ + zmm30 = _mm512_fmadd_pd(zmm17, zmm23, zmm30); \ + zmm31 = _mm512_fmadd_pd(zmm18, zmm23, zmm31); \ + \ + b01_2 += 1; \ + a10_2 += p_lda; \ + } \ + _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + zmm9 = _mm512_add_pd(zmm9, zmm24); \ + _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + zmm10 = _mm512_add_pd(zmm10, zmm25); \ + _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + zmm11 = _mm512_add_pd(zmm11, zmm26); \ + 
_mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + zmm12 = _mm512_add_pd(zmm12, zmm27); \ + _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + zmm13 = _mm512_add_pd(zmm13, zmm28); \ + _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + zmm14 = _mm512_add_pd(zmm14, zmm29); \ + _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + zmm15 = _mm512_add_pd(zmm15, zmm30); \ + _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + zmm16 = _mm512_add_pd(zmm16, zmm31); + +#define BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01)); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 3))); \ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); \ + ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx3n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx2n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_8mx1n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; 
/*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + \ + ymm1 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 0))); \ + ymm2 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 1))); \ + ymm3 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 2))); \ + ymm4 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 3))); \ + ymm5 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 4))); \ + ymm6 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 5))); \ + ymm7 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 6))); \ + ymm8 = _mm256_broadcast_sd((double const*)(b01 + (cs_b * 7))); \ + \ + _mm_prefetch((b01 + 4 * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_fmadd_pd (ymm1, ymm0, ymm9); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ + ymm12 = _mm256_fmadd_pd(ymm4, ymm0, ymm12); \ + ymm13 = _mm256_fmadd_pd(ymm5, ymm0, ymm13); \ + ymm14 = _mm256_fmadd_pd(ymm6, ymm0, ymm14); \ + ymm15 = _mm256_fmadd_pd(ymm7, ymm0, ymm15); \ + ymm16 = _mm256_fmadd_pd(ymm8, ymm0, ymm16); \ + \ + b01 += 1; \ + a10 += p_lda; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 3))); \ + ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(a10)); \ + \ + ymm2 = _mm256_broadcast_sd((double 
const *)(b01 + (cs_b * 0))); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \ + \ + b01 += 1; /*move to next row of B*/ \ + a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \ + } + +// endregion - GEMM DTRSM for left variants + +// region - pre/post DTRSM for left variants + +#define BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0) + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); \ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1) + 2)); \ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); \ + ymm2 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2) + 2)); \ + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); \ + \ + xmm5 = _mm256_castpd256_pd128(ymm8); \ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); \ + _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1)); \ + xmm5 = _mm256_castpd256_pd128(ymm9); \ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); \ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); \ + xmm5 = _mm256_castpd256_pd128(ymm10); \ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); \ + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0) + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); \ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1) + 2)); \ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \ + \ + xmm5 = _mm256_castpd256_pd128(ymm8); \ + _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); \ + _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1)); \ + xmm5 = _mm256_castpd256_pd128(ymm9); \ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); \ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0) + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + \ + xmm5 = _mm256_castpd256_pd128(ymm8); \ + _mm_storeu_pd((double *)(b11), xmm5); \ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm8, 1)); + +#define BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 0))); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 1))); \ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 2))); \ + ymm2 = 
_mm256_insertf128_pd(ymm2, xmm5, 0); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); \ + \ + xmm5 = _mm256_castpd256_pd128(ymm8); \ + _mm_storeu_pd((double *)(b11 + (cs_b * 0)), xmm5); \ + xmm5 = _mm256_castpd256_pd128(ymm9); \ + _mm_storeu_pd((double *)(b11 + (cs_b * 1)), xmm5); \ + xmm5 = _mm256_castpd256_pd128(ymm10); \ + _mm_storeu_pd((double *)(b11 + (cs_b * 2)), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 0))); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 1))); \ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \ + \ + xmm5 = _mm256_castpd256_pd128(ymm8); \ + _mm_storeu_pd((double *)(b11 + (cs_b * 0)), xmm5); \ + xmm5 = _mm256_castpd256_pd128(ymm9); \ + _mm_storeu_pd((double *)(b11 + (cs_b * 1)), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 0))); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + \ + xmm5 = _mm256_castpd256_pd128(ymm8); \ + _mm_storeu_pd((double *)(b11 + (cs_b * 0)), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0))); \ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1))); \ + ymm2 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2))); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \ + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); \ + \ + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0)); \ + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0)); \ + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10, 0)); + +#define BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0))); \ + ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1))); \ + \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \ + \ + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0)); \ + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0)); + +#define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal, b11, cs_b) \ + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \ + \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0)); \ + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \ + \ + _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0)); + // LLNN - LUTN BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 ( @@ -6508,9 +7227,1956 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 cntl_t* cntl ) { - return BLIS_NOT_YET_IMPLEMENTED; + + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); // number of columns + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + dim_t d_mr = 8, d_nr = 8; + + // 
Swap rs_a & cs_a in case of non-tranpose. + if (transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; + dim_t k_iter; + + double AlphaVal = *(double *)AlphaObj->buffer; + double *L = bli_obj_buffer_at_off(a); // pointer to matrix A + double *B = bli_obj_buffer_at_off(b); // pointer to matrix B + + double *a10, *a11, *b01, *b11; // pointers for GEMM and TRSM blocks + + + double ones = 1.0; + + const gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; // pointer to A01 pack buffer + double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack + rntm_t rntm; + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ((d_mr * m * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); + if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s)) + return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if (NULL == D_A_pack) + return BLIS_NULL_POINTER; + } + bool is_unitdiag = bli_obj_has_unit_diag(a); + __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11; + __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15, ymm16; + __m128d xmm5; + xmm5 = _mm_setzero_pd(); + + /* + Performs solving TRSM for 8 columns at a time from 0 to m/8 in steps of d_mr + a. Load, transpose, Pack A (a10 block), the size of packing 8x6 to 8x (m-8) + First there will be no GEMM and no packing of a10 because it is only TRSM + b. Using packed a10 block and b01 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operaton using a11, b11 and update B + d. Repeat b,c for n rows of B in steps of d_nr + */ + for (i = 0; (i + d_mr - 1) < m; i += d_mr) + { + a10 = L + (i * cs_a); + a11 = L + (i * rs_a) + (i * cs_a); + + dim_t p_lda = d_mr; + + if (transa) + { + /* + Pack current A block (a10) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a10 block size will be increasing by d_mr for every next iteration + until it reaches 8x(m-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all n rows of B matrix + */ + bli_dtrsm_small_pack_avx512('L', i, 1, a10, cs_a, D_A_pack, p_lda, d_mr); + /* + Pack 8 diagonal elements of A block into an array + a. This helps to utilize cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, cs_a, d11_pack, d_mr); + } + else + { + bli_dtrsm_small_pack_avx512('L', i, 0, a10, rs_a, D_A_pack, p_lda, d_mr); + dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, rs_a, d11_pack, d_mr); + } + + /* + a. Perform GEMM using a10, b01. + b. 
Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x8 block size + along n dimension for every d_nr columns of B01 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + + for (j = 0; j < n - d_nr + 1; j += d_nr) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); //pointer to block of A to be used for TRSM + b01 = B + j * cs_b; //pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; //pointer to block of B to be used for TRSM + k_iter = i; + + BLIS_SET_ZMM_REG_ZEROS + /* + Perform GEMM between a10 and b01 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8mx8n_AVX512(a10, b01, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x8 and multiply with alpha + Add the GEMM output and perform in register transpose of b11 + to perform TRSM operation. + */ + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8(b11, cs_b, AlphaVal) + + /* + Compute 8x8 TRSM block by using GEMM block output in register + a. The 8x8 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 zmm16 + b. Towards the end TRSM output will be stored back into b11 + */ + // extract a00 + zmm0 = _mm512_set1_pd(*(d11_pack + 0)); + zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0); + + // extract a11 + zmm1 = _mm512_set1_pd(*(d11_pack + 1)); + zmm2 = _mm512_set1_pd(*(a11 + (1 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm2, zmm9, zmm10); + zmm3 = _mm512_set1_pd(*(a11 + (2 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm3, zmm9, zmm11); + zmm4 = _mm512_set1_pd(*(a11 + (3 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm4, zmm9, zmm12); + zmm5 = _mm512_set1_pd(*(a11 + (4 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm5, zmm9, zmm13); + zmm6 = _mm512_set1_pd(*(a11 + (5 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm6, zmm9, zmm14); + zmm7 = _mm512_set1_pd(*(a11 + (6 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm7, zmm9, zmm15); + zmm8 = _mm512_set1_pd(*(a11 + (7 * cs_a))); + zmm16 = _mm512_fnmadd_pd(zmm8, zmm9, zmm16); + zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1); + a11 += rs_a; + + // extract a22 + zmm0 = _mm512_set1_pd(*(d11_pack + 2)); + zmm2 = _mm512_set1_pd(*(a11 + (2 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm2, zmm10, zmm11); + zmm3 = _mm512_set1_pd(*(a11 + (3 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm3, zmm10, zmm12); + zmm4 = _mm512_set1_pd(*(a11 + (4 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm4, zmm10, zmm13); + zmm5 = _mm512_set1_pd(*(a11 + (5 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm5, zmm10, zmm14); + zmm6 = _mm512_set1_pd(*(a11 + (6 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm6, zmm10, zmm15); + zmm7 = _mm512_set1_pd(*(a11 + (7 * cs_a))); + zmm16 = _mm512_fnmadd_pd(zmm7, zmm10, zmm16); + zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm0); + a11 += rs_a; + + // extract a33 + zmm1 = _mm512_set1_pd(*(d11_pack + 3)); + zmm2 = _mm512_set1_pd(*(a11 + (3 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm2, zmm11, zmm12); + zmm3 = _mm512_set1_pd(*(a11 + (4 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm3, zmm11, zmm13); + zmm4 = _mm512_set1_pd(*(a11 + (5 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm4, zmm11, zmm14); + zmm5 = _mm512_set1_pd(*(a11 + (6 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm5, zmm11, zmm15); + zmm6 = _mm512_set1_pd(*(a11 + (7 * cs_a))); + zmm16 = _mm512_fnmadd_pd(zmm6, zmm11, zmm16); + zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1); + a11 += rs_a; + + // extract a44 + zmm0 = _mm512_set1_pd(*(d11_pack + 4)); + zmm2 
= _mm512_set1_pd(*(a11 + (4 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm2, zmm12, zmm13); + zmm3 = _mm512_set1_pd(*(a11 + (5 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm3, zmm12, zmm14); + zmm4 = _mm512_set1_pd(*(a11 + (6 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm4, zmm12, zmm15); + zmm5 = _mm512_set1_pd(*(a11 + (7 * cs_a))); + zmm16 = _mm512_fnmadd_pd(zmm5, zmm12, zmm16); + zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm0); + a11 += rs_a; + + // extract a55 + zmm1 = _mm512_set1_pd(*(d11_pack + 5)); + zmm2 = _mm512_set1_pd(*(a11 + (5 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm2, zmm13, zmm14); + zmm3 = _mm512_set1_pd(*(a11 + (6 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm3, zmm13, zmm15); + zmm4 = _mm512_set1_pd(*(a11 + (7 * cs_a))); + zmm16 = _mm512_fnmadd_pd(zmm4, zmm13, zmm16); + zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1); + a11 += rs_a; + + // extract a66 + zmm0 = _mm512_set1_pd(*(d11_pack + 6)); + zmm2 = _mm512_set1_pd(*(a11 + (6 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm2, zmm14, zmm15); + zmm3 = _mm512_set1_pd(*(a11 + (7 * cs_a))); + zmm16 = _mm512_fnmadd_pd(zmm3, zmm14, zmm16); + zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm0); + a11 += rs_a; + + // extract a77 + zmm1 = _mm512_set1_pd(*(d11_pack + 7)); + zmm2 = _mm512_set1_pd(*(a11 + 7 * cs_a)); + zmm16 = _mm512_fnmadd_pd(zmm2, zmm15, zmm16); + zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm1); + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8_AND_STORE(b11, cs_b) + _mm512_storeu_pd((double *)(b11 + (cs_b * 0)), zmm0); + _mm512_storeu_pd((double *)(b11 + (cs_b * 1)), zmm1); + _mm512_storeu_pd((double *)(b11 + (cs_b * 2)), zmm2); + _mm512_storeu_pd((double *)(b11 + (cs_b * 3)), zmm3); + _mm512_storeu_pd((double *)(b11 + (cs_b * 4)), zmm4); + _mm512_storeu_pd((double *)(b11 + (cs_b * 5)), zmm5); + _mm512_storeu_pd((double *)(b11 + (cs_b * 6)), zmm6); + _mm512_storeu_pd((double *)(b11 + (cs_b * 7)), zmm7); + } + dim_t n_rem = n - j; + if (n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 0))); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 1))); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3))); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm4 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 0) + 4)); + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 1) + 4)); + // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2) + 4)); + // B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3) + 4)); + // B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] 
* alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); // B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); // B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); // B11[0-3][7] * alpha -= B01[0-3][7] + + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1))); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2))); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2))); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = 
_mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + // perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + // perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + // perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + // extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + // perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); + // 
B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); + // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); + // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); + // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); + // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); + // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); + // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); + // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); + // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); + // B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); // store B11[7][0-3] + + n_rem -= 4; + j += 4; + } + + if (n_rem) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + j * cs_b; // pointer to block of B to be used for GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); + // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); + // B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = 
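+                    // For this n_rem fringe the registers of the absent B11
+                    // columns are loaded with broadcast `ones` so the shared
+                    // transpose and substitution code can run unchanged; only
+                    // the valid n_rem columns are stored back afterwards.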
_mm256_fmsub_pd(ymm1, ymm16, ymm9); + // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); + // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); + // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); + // B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); + // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // 
B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1))); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2))); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2))); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + // perform mul operation + ymm11 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + // perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 7)); + + a11 += rs_a; + + // extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + // perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + // perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + // extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + // perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = 
_mm256_permute2f128_pd(ymm12, ymm13, 0x31); // B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if (3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3] + } + else if (2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + } + else if (1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + } + } + } + + //======================M remainder cases================================ + dim_t m_rem = m - i; + + if (m_rem >= 4) // implementation for reamainder rows(when 'M' is not a multiple of d_mr) + { + a10 = L + (i * cs_a); // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3 + 3)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + rs_a * 1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + rs_a * 2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + rs_a * 3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef 
BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM operation to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter); + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8(b11, cs_b, AlphaVal); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 * cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm3, ymm9, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm3, ymm13, ymm15); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm9, ymm12); + ymm16 = _mm256_fnmadd_pd(ymm4, ymm13, ymm16); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + a11 += rs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm3, ymm10, ymm12); + ymm16 = _mm256_fnmadd_pd(ymm3, ymm14, ymm16); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm0); + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm16 = _mm256_fnmadd_pd(ymm2, ymm15, ymm16); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + + a11 += rs_a; + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8_AND_STORE(b11, cs_b) + + _mm256_storeu_pd((double *)(b11 + 0 * cs_b), ymm0); + _mm256_storeu_pd((double *)(b11 + 1 * cs_b), ymm1); + _mm256_storeu_pd((double *)(b11 + 2 * cs_b), ymm2); + _mm256_storeu_pd((double *)(b11 + 3 * cs_b), ymm3); + _mm256_storeu_pd((double *)(b11 + 4 * cs_b), ymm4); + _mm256_storeu_pd((double *)(b11 + 5 * cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + 6 * cs_b), ymm6); + _mm256_storeu_pd((double *)(b11 + 7 * cs_b), ymm7); + } + + dim_t n_rem = n - j; + if (n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + j * cs_b; // pointer to block of B to be used for GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, 
k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm2,ymm12,ymm13); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm3,ymm12,ymm14); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); + + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] 
B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if (n_rem) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + j * cs_b; // pointer to block of B to be used for GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + /// transpose of B11// + /// unpacklow/// + ymm9 = 
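+                // The unpacklo/unpackhi + permute2f128 sequence below transposes
+                // the 4x4 tile of B11 so that each ymm register holds one row of X,
+                // which is the layout the row-wise substitution steps expect.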
_mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if (3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + 
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + } + else if (2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + } + else if (1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + } + } + m_rem -= 4; + i += 4; + } + + if (m_rem) + { + a10 = L + (i * cs_a); // pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if (3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + // cols + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + + ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal)); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm4 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4)); + ymm5 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 4 + 2)); + ymm5 = _mm256_insertf128_pd(ymm5, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5)); + ymm6 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 
5 + 2)); + ymm6 = _mm256_insertf128_pd(ymm6, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6)); + ymm7 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 6 + 2)); + ymm7 = _mm256_insertf128_pd(ymm7, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7)); + ymm8 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 7 + 2)); + ymm8 = _mm256_insertf128_pd(ymm8, xmm5, 0); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf128_pd(ymm9, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm10, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm11, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm12, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm13, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm14, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf128_pd(ymm15, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf128_pd(ymm16, 0)); + + _mm_storel_pd((double *)(b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm11, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm12, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf128_pd(ymm13, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf128_pd(ymm14, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 6 + 2), _mm256_extractf128_pd(ymm15, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 7 + 2), _mm256_extractf128_pd(ymm16, 1)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 8, rs_a, cs_b, is_unitdiag); + } + + dim_t n_rem = n - j; + if ((n_rem >= 4)) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 
3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11, 1)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j += 4; + } + + if (n_rem) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if (2 == m_rem) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; 
x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + // cols + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal)); + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4)); + ymm5 = _mm256_insertf128_pd(ymm5, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5)); + ymm6 = _mm256_insertf128_pd(ymm6, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6)); + ymm7 = _mm256_insertf128_pd(ymm7, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7)); + ymm8 = _mm256_insertf128_pd(ymm8, xmm5, 0); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf128_pd(ymm9, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm10, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm11, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm12, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm13, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm14, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf128_pd(ymm15, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf128_pd(ymm16, 0)); + + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 8, rs_a, cs_b, is_unitdiag); + } + + dim_t n_rem = n - j; + if ((n_rem >= 4)) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT +; + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + xmm5 
= _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j += 4; + } + if (n_rem) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -= 2; + i += 2; + } + else if (1 == m_rem) // Repetative A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup 
+= p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + // cols + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter); + + /// GEMM code ends/// + ymm0 = _mm256_broadcast_sd((double const*)(&AlphaVal)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2)); + ymm4 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3)); + ymm5 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 4)); + ymm6 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 5)); + ymm7 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 6)); + ymm8 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 7)); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storel_pd((double *)(b11 + cs_b * 0), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm10, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm11, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm12, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm13, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm14, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 6), _mm256_extractf128_pd(ymm15, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 7), _mm256_extractf128_pd(ymm16, 0)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 8, rs_a, cs_b, is_unitdiag); + } + dim_t n_rem = n - j; + if ((n_rem >= 4)) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3)); + + ymm8 = 
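+                // With a single remaining row only the alpha scaling and the GEMM
+                // update are vectorized here; the triangular solve itself is
+                // delegated to the scalar reference kernels
+                // (dtrsm_AutXB_ref / dtrsm_AlXB_ref).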
_mm256_fmsub_pd(ymm0, ymm16, ymm8);
+                ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+                ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+                ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+                _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm8, 0));
+                _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0));
+                _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm10, 0));
+                _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm11, 0));
+
+                if (transa)
+                    dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag);
+                n_rem -= 4;
+                j += 4;
+            }
+
+            if (n_rem)
+            {
+                a10 = D_A_pack; // pointer to block of A to be used for GEMM
+                a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM
+                b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM
+                b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM
+
+                k_iter = i; // number of times GEMM to be performed(in blocks of 4x4)
+
+                /*Fill zeros into ymm registers used in gemm accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                if (3 == n_rem)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag);
+                    else dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag);
+                }
+                else if (2 == n_rem)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag);
+                    else dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag);
+                }
+                else if (1 == n_rem)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag);
+                    else dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag);
+                }
+            }
+            m_rem -= 1;
+            i += 1;
+        }
+    }
+
+    if ((required_packing_A == 1) &&
+        bli_mem_is_alloc(&local_mem_buf_A_s))
+    {
+        bli_membrk_release(&rntm, &local_mem_buf_A_s);
+    }
+    return BLIS_SUCCESS;
}
+// LUNN LUTN
BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512
(