diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c
index 2b0c4bc5e..f58aa3710 100644
--- a/frame/compat/bla_trsm_amd.c
+++ b/frame/compat/bla_trsm_amd.c
@@ -1059,10 +1059,9 @@ void dtrsm_blis_impl
     {
       case BLIS_ARCH_ZEN4:
 #if defined(BLIS_KERNELS_ZEN4)
-          // check if variant is RUN[N/U] or RLT[N/U]
           // this is a temporary fix, will be removed when all variants are added
-          if( (blis_side == BLIS_RIGHT) &&
-              ((n0 > 300) && (m0 > 50)))
+          if( ((blis_side == BLIS_RIGHT) && ((n0 > 300) && (m0 > 50))) ||
+              ((blis_side == BLIS_LEFT &&
+                ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_NO_TRANSPOSE) ||
+                  (blis_uploa == BLIS_UPPER && blis_transa == BLIS_TRANSPOSE) ) ) &&
+               ((n0 != 30 && n0 != 60) && (m0 > 50))) )
          {
            ker_ft = bli_trsm_small_AVX512;
          }
@@ -1089,13 +1088,13 @@ void dtrsm_blis_impl
     {
       case BLIS_ARCH_ZEN4:
 #if defined(BLIS_KERNELS_ZEN4)
-          if( blis_side == BLIS_RIGHT )
+          if ( (blis_side == BLIS_LEFT &&
+                ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_TRANSPOSE) ||
+                  (blis_uploa == BLIS_UPPER && blis_transa == BLIS_NO_TRANSPOSE) ) ) )
          {
-            ker_ft = bli_trsm_small_mt_AVX512;
+            ker_ft = bli_trsm_small_mt;
          }
          else
          {
-            ker_ft = bli_trsm_small_mt;
+            ker_ft = bli_trsm_small_mt_AVX512;
          }
          break;
 #endif// BLIS_KERNELS_ZEN4
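Note: the new dispatch predicate packs two variant tests into one expression. It is easier to audit when the tests are named; a minimal sketch of an equivalent rewrite (the helper names is_right_variant / is_llnn_lutn_variant are hypothetical, not part of this patch):

    bool is_right_variant = (blis_side == BLIS_RIGHT) && (n0 > 300) && (m0 > 50);
    bool is_llnn_lutn_variant =
        (blis_side == BLIS_LEFT) &&
        ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_NO_TRANSPOSE) ||  // LLN[N/U]
          (blis_uploa == BLIS_UPPER && blis_transa == BLIS_TRANSPOSE) ) &&   // LUT[N/U]
        (n0 != 30) && (n0 != 60) && (m0 > 50);

    if( is_right_variant || is_llnn_lutn_variant )
    {
      ker_ft = bli_trsm_small_AVX512;
    }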
diff --git a/frame/include/bli_trsm_small_ref.h b/frame/include/bli_trsm_small_ref.h
new file mode 100644
index 000000000..715db884e
--- /dev/null
+++ b/frame/include/bli_trsm_small_ref.h
@@ -0,0 +1,129 @@
+#ifdef BLIS_ENABLE_TRSM_PREINVERSION
+  #define DIAG_ELE_INV_OPS(a, b)  (a / b)
+  #define DIAG_ELE_EVAL_OPS(a, b) (a * b)
+#endif
+
+#ifdef BLIS_DISABLE_TRSM_PREINVERSION
+  #define DIAG_ELE_INV_OPS(a, b)  (a * b)
+  #define DIAG_ELE_EVAL_OPS(a, b) (a / b)
+#endif
+
+// reference code for LUTN
+BLIS_INLINE err_t dtrsm_AutXB_ref
+    (
+      double *A,
+      double *B,
+      dim_t   M,
+      dim_t   N,
+      dim_t   lda,
+      dim_t   ldb,
+      bool    unitDiagonal
+    )
+{
+  dim_t i, j, k;
+  for (k = 0; k < M; k++)
+  {
+    double lkk_inv = 1.0;
+    if (!unitDiagonal)
+      lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
+    for (j = 0; j < N; j++)
+    {
+      B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
+      for (i = k + 1; i < M; i++)
+      {
+        B[i + j * ldb] -= A[i * lda + k] * B[k + j * ldb];
+      }
+    }
+  } // k-loop
+  return BLIS_SUCCESS;
+}
+
+// reference code for LLNN
+BLIS_INLINE err_t dtrsm_AlXB_ref
+    (
+      double *A,
+      double *B,
+      dim_t   M,
+      dim_t   N,
+      dim_t   lda,
+      dim_t   ldb,
+      bool    is_unitdiag
+    )
+{
+  dim_t i, j, k;
+  for (k = 0; k < M; k++)
+  {
+    double lkk_inv = 1.0;
+    if (!is_unitdiag)
+      lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
+    for (j = 0; j < N; j++)
+    {
+      B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
+      for (i = k + 1; i < M; i++)
+      {
+        B[i + j * ldb] -= A[i + k * lda] * B[k + j * ldb];
+      }
+    }
+  } // k-loop
+  return BLIS_SUCCESS;
+}
+
+// reference code for LUNN
+BLIS_INLINE err_t dtrsm_AuXB_ref
+    (
+      double *A,
+      double *B,
+      dim_t   M,
+      dim_t   N,
+      dim_t   lda,
+      dim_t   ldb,
+      bool    is_unitdiag
+    )
+{
+  dim_t i, j, k;
+  for (k = M - 1; k >= 0; k--)
+  {
+    double lkk_inv = 1.0;
+    if (!is_unitdiag)
+      lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
+    for (j = N - 1; j >= 0; j--)
+    {
+      B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
+      for (i = k - 1; i >= 0; i--)
+      {
+        B[i + j * ldb] -= A[i + k * lda] * B[k + j * ldb];
+      }
+    }
+  } // k-loop
+  return BLIS_SUCCESS;
+} // end of function
+
+// reference code for LLTN
+BLIS_INLINE err_t dtrsm_AltXB_ref
+    (
+      double *A,
+      double *B,
+      dim_t   M,
+      dim_t   N,
+      dim_t   lda,
+      dim_t   ldb,
+      bool    is_unitdiag
+    )
+{
+  dim_t i, j, k;
+  for (k = M - 1; k >= 0; k--)
+  {
+    double lkk_inv = 1.0;
+    if (!is_unitdiag)
+      lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);
+    for (j = N - 1; j >= 0; j--)
+    {
+      B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);
+      for (i = k - 1; i >= 0; i--)
+      {
+        B[i + j * ldb] -= A[i * lda + k] * B[k + j * ldb];
+      }
+    }
+  } // k-loop
+  return BLIS_SUCCESS;
+} // end of function
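Note: the ToDo in the old bli_trsm_small.c ("We can combine all these reference implementation into a macro") still applies to this new header: the four kernels differ only in loop direction and in how A is indexed. A minimal sketch of that unification (GEN_DTRSM_SMALL_REF is hypothetical, not part of this patch; the unitDiagonal parameter name is unified to is_unitdiag):

    #define GEN_DTRSM_SMALL_REF( name, A_ELEM, FORWARD )                  \
    BLIS_INLINE err_t name                                                \
        (                                                                 \
          double *A, double *B, dim_t M, dim_t N,                         \
          dim_t lda, dim_t ldb, bool is_unitdiag                          \
        )                                                                 \
    {                                                                     \
      dim_t i, j, k;                                                      \
      for (dim_t kk = 0; kk < M; kk++)                                    \
      {                                                                   \
        k = FORWARD ? kk : M - 1 - kk;                                    \
        double lkk_inv = 1.0;                                             \
        if (!is_unitdiag)                                                 \
          lkk_inv = DIAG_ELE_INV_OPS(lkk_inv, A[k + k * lda]);            \
        for (j = 0; j < N; j++)                                           \
        {                                                                 \
          B[k + j * ldb] = DIAG_ELE_EVAL_OPS(B[k + j * ldb], lkk_inv);    \
          dim_t i_begin = FORWARD ? k + 1 : 0;                            \
          dim_t i_end   = FORWARD ? M     : k;                            \
          for (i = i_begin; i < i_end; i++)                               \
          {                                                               \
            B[i + j * ldb] -= A_ELEM * B[k + j * ldb];                    \
          }                                                               \
        }                                                                 \
      }                                                                   \
      return BLIS_SUCCESS;                                                \
    }

    // usage (names match the patch):
    GEN_DTRSM_SMALL_REF( dtrsm_AutXB_ref, A[i * lda + k], 1 ) // LUTN
    GEN_DTRSM_SMALL_REF( dtrsm_AlXB_ref,  A[i + k * lda], 1 ) // LLNN
    GEN_DTRSM_SMALL_REF( dtrsm_AuXB_ref,  A[i + k * lda], 0 ) // LUNN
    GEN_DTRSM_SMALL_REF( dtrsm_AltXB_ref, A[i * lda + k], 0 ) // LLTN

The i/j loop directions of the backward kernels are flipped to ascending order here; the row updates are independent, so the result is unchanged.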
diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c
index 8cb7bb786..15dec8733 100644
--- a/kernels/zen/3/bli_trsm_small.c
+++ b/kernels/zen/3/bli_trsm_small.c
@@ -35,6 +35,7 @@
 #include "blis.h"
 #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
 #include "immintrin.h"
+#include "bli_trsm_small_ref.h"
 
 #define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
 
@@ -107,18 +108,6 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB
   cntl_t* cntl
 );
-//AX = B; A is lower triangular; transpose;
-//double precision; non-unit diagonal
-BLIS_INLINE err_t dtrsm_AltXB_ref
-(
-    double *A,
-    double *B,
-    dim_t M,
-    dim_t N,
-    dim_t lda,
-    dim_t ldb,
-    bool is_unitdiag
-);
 /*
  * ZTRSM kernel declaration
  */
@@ -248,41 +237,6 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB
   #define DIAG_ELE_EVAL_OPS(a,b) (a / b)
 #endif
 
-/*
- * Reference implementations
- * ToDo: We can combine all these reference implementation
-   into a macro
-*/
-//A'X = B; A is upper triangular; transpose;
-//non-unitDiagonal double precision
-BLIS_INLINE err_t dtrsm_AutXB_ref
-(
-    double *A,
-    double *B,
-    dim_t M,
-    dim_t N,
-    dim_t lda,
-    dim_t ldb,
-    bool unitDiagonal
-)
-{
-    dim_t i, j, k;
-    for (k = 0; k < M; k++)
-    {
-        double lkk_inv = 1.0;
-        if(!unitDiagonal) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
-        for (j = 0; j < N; j++)
-        {
-            B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb] , lkk_inv);
-            for (i = k+1; i < M; i++)
-            {
-                B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb];
-            }
-        }
-    }// k -loop
-    return BLIS_SUCCESS;
-}// end of function
-
 /*
  * Reference implementations
  * ToDo: We can combine all these reference implementation
@@ -318,37 +272,6 @@ BLIS_INLINE err_t strsm_AutXB_ref
     return BLIS_SUCCESS;
 }// end of function
 
-/* TRSM scalar code for the case AX = alpha * B
- * A is upper-triangular, non-unit-diagonal
- * Dimensions: A: mxm X: mxn B:mxn
- */
-BLIS_INLINE err_t dtrsm_AuXB_ref
-(
-    double *A,
-    double *B,
-    dim_t M,
-    dim_t N,
-    dim_t lda,
-    dim_t ldb,
-    bool is_unitdiag
-)
-{
-    dim_t i, j, k;
-    for (k = M-1; k >= 0; k--)
-    {
-        double lkk_inv = 1.0;
-        if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
-        for (j = N -1; j >= 0; j--)
-        {
-            B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv);
-            for (i = k-1; i >=0; i--)
-            {
-                B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
-            }
-        }
-    }// k -loop
-    return BLIS_SUCCESS;
-}// end of function
 
 /* TRSM scalar code for the case AX = alpha * B
  * A is upper-triangular, non-unit-diagonal
@@ -382,37 +305,6 @@ BLIS_INLINE err_t strsm_AuXB_ref
     return BLIS_SUCCESS;
 }// end of function
 
-/* TRSM scalar code for the case AX = alpha * B
- * A is lower-triangular, non-unit-diagonal, no transpose
- * Dimensions: A: mxm X: mxn B:mxn
- */
-BLIS_INLINE err_t dtrsm_AlXB_ref
-(
-    double *A,
-    double *B,
-    dim_t M,
-    dim_t N,
-    dim_t lda,
-    dim_t ldb,
-    bool is_unitdiag
-)
-{
-    dim_t i, j, k;
-    for (k = 0; k < M; k++)
-    {
-        double lkk_inv = 1.0;
-        if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
-        for (j = 0; j < N; j++)
-        {
-            B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv);
-            for (i = k+1; i < M; i++)
-            {
-                B[i + j*ldb] -= A[i + k*lda] * B[k + j*ldb];
-            }
-        }
-    }// k -loop
-    return BLIS_SUCCESS;
-}// end of function
 
 /* TRSM scalar code for the case AX = alpha * B
  * A is lower-triangular, non-unit-diagonal, no transpose
@@ -446,38 +338,6 @@ BLIS_INLINE err_t strsm_AlXB_ref
     return BLIS_SUCCESS;
 }// end of function
 
-/* TRSM scalar code for the case AX = alpha * B
- * A is lower-triangular, non-unit-diagonal, transpose
- * Dimensions: A: mxm X: mxn B:mxn
- */
-BLIS_INLINE err_t dtrsm_AltXB_ref
-(
-    double *A,
-    double *B,
-    dim_t M,
-    dim_t N,
-    dim_t lda,
-    dim_t ldb,
-    bool is_unitdiag
-)
-{
-    dim_t i, j, k;
-    for (k = M-1; k >= 0; k--)
-    {
-        double lkk_inv = 1.0;
-        if(!is_unitdiag) lkk_inv = DIAG_ELE_INV_OPS(lkk_inv,A[k+k*lda]);
-        for (j = N -1; j >= 0; j--)
-        {
-            B[k + j*ldb] = DIAG_ELE_EVAL_OPS(B[k + j*ldb],lkk_inv);
-            for (i = k-1; i >=0; i--)
-            {
-                B[i + j*ldb] -= A[i*lda + k] * B[k + j*ldb];
-            }
-        }
-    }// k -loop
-    return BLIS_SUCCESS;
-}// end of function
-
 /* TRSM scalar code for the case AX = alpha * B
  * A is lower-triangular, non-unit-diagonal, transpose
  * Dimensions: A: mxm X: mxn B:mxn
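Note: with the DIAG_ELE_* pair now living only in the shared header, it is worth spelling out what the two build configurations do. A standalone illustration (not from the patch) of the macro expansions for one row update:

    /* BLIS_ENABLE_TRSM_PREINVERSION: one divide per diagonal element,
       then a multiply per row entry                                     */
    double lkk_inv = 1.0 / A[k + k * lda];        /* DIAG_ELE_INV_OPS  */
    B[k + j * ldb] = B[k + j * ldb] * lkk_inv;    /* DIAG_ELE_EVAL_OPS */

    /* BLIS_DISABLE_TRSM_PREINVERSION: no reciprocal is formed; each
       row entry is divided directly                                    */
    double lkk = 1.0 * A[k + k * lda];            /* DIAG_ELE_INV_OPS  */
    B[k + j * ldb] = B[k + j * ldb] / lkk;        /* DIAG_ELE_EVAL_OPS */

The usual motivation for pre-inversion is speed (multiplies instead of divides in the inner loop), at the cost of one extra rounding step per diagonal element.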
diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c
index b9c4bd3b4..639ada81e 100644
--- a/kernels/zen4/3/bli_trsm_small_AVX512.c
+++ b/kernels/zen4/3/bli_trsm_small_AVX512.c
@@ -27,20 +27,13 @@
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
 #include "blis.h"
+#include "bli_trsm_small_ref.h"
 #ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM
 #include "immintrin.h"
 
 #define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL
 
-#ifdef BLIS_ENABLE_TRSM_PREINVERSION
-  #define DIAG_ELE_INV_OPS(a, b)  (a / b)
-  #define DIAG_ELE_EVAL_OPS(a, b) (a * b)
-#endif
-
-#ifdef BLIS_DISABLE_TRSM_PREINVERSION
-  #define DIAG_ELE_INV_OPS(a, b)  (a * b)
-  #define DIAG_ELE_EVAL_OPS(a, b) (a / b)
-#endif
 
 #ifdef BLIS_DISABLE_TRSM_PREINVERSION
   #define DTRSM_SMALL_DIV_OR_SCALE _mm256_div_pd
@@ -86,6 +79,17 @@
   ymm30 = _mm256_setzero_pd(); \
   ymm31 = _mm256_setzero_pd();
 
+#define BLIS_SET_YMM_REG_ZEROS_FOR_LEFT \
+  ymm8  = _mm256_setzero_pd(); \
+  ymm9  = _mm256_setzero_pd(); \
+  ymm10 = _mm256_setzero_pd(); \
+  ymm11 = _mm256_setzero_pd(); \
+  ymm12 = _mm256_setzero_pd(); \
+  ymm13 = _mm256_setzero_pd(); \
+  ymm14 = _mm256_setzero_pd(); \
+  ymm15 = _mm256_setzero_pd(); \
+  ymm16 = _mm256_setzero_pd();
+
 #define BLIS_SET_ZMM_REG_ZEROS \
   zmm0 = _mm512_setzero_pd(); \
   zmm1 = _mm512_setzero_pd(); \
@@ -143,6 +147,7 @@ typedef err_t (*trsmsmall_ker_ft)
     cntl_t* cntl
   );
 
+
 /*
   Pack a block of 8xk from input buffer into packed buffer
   directly or after transpose based on input params
@@ -163,7 +168,104 @@ BLIS_INLINE void bli_dtrsm_small_pack_avx512
   __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
   if (side == 'L' || side == 'l')
   {
-    return; // BLIS_NOT_YET_IMPLEMENTED
+    /* Left case is 8xk */
+    if (trans)
+    {
+      __m256d ymm0, ymm1, ymm2, ymm3;
+      __m256d ymm4, ymm5, ymm6, ymm7;
+      __m256d ymm8, ymm9, ymm10, ymm11;
+      __m256d ymm12, ymm13;
+      for (dim_t x = 0; x < size; x += mr)
+      {
+        ymm0  = _mm256_loadu_pd((double const *)(inbuf));
+        ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4));
+        ymm1  = _mm256_loadu_pd((double const *)(inbuf + cs_a));
+        ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a));
+        ymm2  = _mm256_loadu_pd((double const *)(inbuf + cs_a * 2));
+        ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 2));
+        ymm3  = _mm256_loadu_pd((double const *)(inbuf + cs_a * 3));
+        ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + cs_a * 3));
+
+        ymm4 = _mm256_unpacklo_pd(ymm0, ymm1);
+        ymm5 = _mm256_unpacklo_pd(ymm2, ymm3);
+        ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+        ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+        ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);
+        ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);
+        ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+        ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+        _mm256_storeu_pd((double *)(pbuff),             ymm6);
+        _mm256_storeu_pd((double *)(pbuff + p_lda),     ymm7);
+        _mm256_storeu_pd((double *)(pbuff + p_lda * 2), ymm8);
+        _mm256_storeu_pd((double *)(pbuff + p_lda * 3), ymm9);
+
+        ymm4 = _mm256_unpacklo_pd(ymm10, ymm11);
+        ymm5 = _mm256_unpacklo_pd(ymm12, ymm13);
+
+        ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+        ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+
+        ymm0 = _mm256_unpackhi_pd(ymm10, ymm11);
+        ymm1 = _mm256_unpackhi_pd(ymm12, ymm13);
+
+        ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+        ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+        _mm256_storeu_pd((double *)(pbuff + p_lda * 4), ymm6);
+        _mm256_storeu_pd((double *)(pbuff + p_lda * 5), ymm7);
+        _mm256_storeu_pd((double *)(pbuff + p_lda * 6), ymm8);
+        _mm256_storeu_pd((double *)(pbuff + p_lda * 7), ymm9);
+
+        ymm0  = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4));
+        ymm10 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 4 + 4));
+        ymm1  = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5));
+        ymm11 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 5 + 4));
+        ymm2  = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6));
+        ymm12 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 6 + 4));
+        ymm3  = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7));
+        ymm13 = _mm256_loadu_pd((double const *)(inbuf + cs_a * 7 + 4));
+
+        ymm4 = _mm256_unpacklo_pd(ymm0, ymm1);
+        ymm5 = _mm256_unpacklo_pd(ymm2, ymm3);
+        ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+        ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+        ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);
+        ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);
+        ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+        ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+        _mm256_storeu_pd((double *)(pbuff + 4),             ymm6);
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda),     ymm7);
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 2), ymm8);
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 3), ymm9);
+
+        ymm4 = _mm256_unpacklo_pd(ymm10, ymm11);
+        ymm5 = _mm256_unpacklo_pd(ymm12, ymm13);
+        ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+        ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+        ymm0 = _mm256_unpackhi_pd(ymm10, ymm11);
+        ymm1 = _mm256_unpackhi_pd(ymm12, ymm13);
+        ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+        ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 4), ymm6);
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 5), ymm7);
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 6), ymm8);
+        _mm256_storeu_pd((double *)(pbuff + 4 + p_lda * 7), ymm9);
+
+        inbuf += mr;
+        pbuff += mr * mr;
+      }
+    }
+    else
+      for (dim_t x = 0; x < size; x++)
+      {
+        zmm0 = _mm512_loadu_pd((double const *)(inbuf));
+        _mm512_storeu_pd((double *)(pbuff), zmm0);
+        inbuf += cs_a;
+        pbuff += p_lda;
+      }
   }
   else if (side == 'R' || side == 'r')
   {
@@ -623,6 +725,7 @@ err_t bli_trsm_small_mt_AVX512
 } // End of function
 #endif
 
+
 // region - GEMM DTRSM for right variants
 
 #define BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \
@@ -6498,6 +6601,622 @@
 else if ( n_remainder == 2)
   return BLIS_SUCCESS;
 }
+// region - 8x8 transpose for left variants
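+// The NREG_TRANSPOSE macros below scale a b11 tile by alpha, subtract the
+// GEMM accumulators, and transpose the tile in registers: unpacklo/unpackhi
+// interleave adjacent rows, then two _mm512_shuffle_f64x2 rounds regroup the
+// 128-bit lanes, leaving zmm9..zmm16 holding b11 transposed.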
+#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8(b11, cs_b, AlphaVal) \
+  zmm8 = _mm512_set1_pd(AlphaVal); \
+  zmm0 = _mm512_loadu_pd((double const *)b11 + (cs_b * 0)); \
+  zmm1 = _mm512_loadu_pd((double const *)b11 + (cs_b * 1)); \
+  zmm2 = _mm512_loadu_pd((double const *)b11 + (cs_b * 2)); \
+  zmm3 = _mm512_loadu_pd((double const *)b11 + (cs_b * 3)); \
+  zmm0 = _mm512_fmsub_pd(zmm0, zmm8, zmm9);  \
+  zmm1 = _mm512_fmsub_pd(zmm1, zmm8, zmm10); \
+  zmm2 = _mm512_fmsub_pd(zmm2, zmm8, zmm11); \
+  zmm3 = _mm512_fmsub_pd(zmm3, zmm8, zmm12); \
+  \
+  zmm4 = _mm512_loadu_pd((double const *)b11 + (cs_b * 4)); \
+  zmm5 = _mm512_loadu_pd((double const *)b11 + (cs_b * 5)); \
+  zmm6 = _mm512_loadu_pd((double const *)b11 + (cs_b * 6)); \
+  zmm7 = _mm512_loadu_pd((double const *)b11 + (cs_b * 7)); \
+  zmm4 = _mm512_fmsub_pd(zmm4, zmm8, zmm13); \
+  zmm5 = _mm512_fmsub_pd(zmm5, zmm8, zmm14); \
+  zmm6 = _mm512_fmsub_pd(zmm6, zmm8, zmm15); \
+  zmm7 = _mm512_fmsub_pd(zmm7, zmm8, zmm16); \
+  /*Stage1*/ \
+  zmm17 = _mm512_unpacklo_pd(zmm0, zmm1); \
+  zmm18 = _mm512_unpacklo_pd(zmm2, zmm3); \
+  zmm19 = _mm512_unpacklo_pd(zmm4, zmm5); \
+  zmm20 = _mm512_unpacklo_pd(zmm6, zmm7); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \
+  /*Stage3 1,5*/ \
+  zmm9  = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm13 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \
+  /*Stage3 3,7*/ \
+  zmm11 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm15 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \
+  /*Stage1*/ \
+  zmm17 = _mm512_unpackhi_pd(zmm0, zmm1); \
+  zmm18 = _mm512_unpackhi_pd(zmm2, zmm3); \
+  zmm19 = _mm512_unpackhi_pd(zmm4, zmm5); \
+  zmm20 = _mm512_unpackhi_pd(zmm6, zmm7); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \
+  /*Stage3 2,6*/ \
+  zmm10 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm14 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \
+  /*Stage3 4,8*/ \
+  zmm12 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm16 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101);
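+// Same pattern with 256-bit registers: transposes a 4x8 tile of b11,
+// apparently for the fringe paths where fewer than 8 rows remain.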
+#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8(b11, cs_b, AlphaVal) \
+  ymm8 = _mm256_broadcast_sd((double const *)(&AlphaVal)); \
+  \
+  ymm0 = _mm256_loadu_pd((double const *)(b11)); \
+  ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 1))); \
+  ymm2 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); \
+  ymm3 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3))); \
+  ymm0 = _mm256_fmsub_pd(ymm0, ymm8, ymm9);  \
+  ymm1 = _mm256_fmsub_pd(ymm1, ymm8, ymm10); \
+  ymm2 = _mm256_fmsub_pd(ymm2, ymm8, ymm11); \
+  ymm3 = _mm256_fmsub_pd(ymm3, ymm8, ymm12); \
+  \
+  ymm10 = _mm256_unpacklo_pd(ymm0, ymm1); \
+  ymm12 = _mm256_unpacklo_pd(ymm2, ymm3); \
+  ymm9  = _mm256_permute2f128_pd(ymm10, ymm12, 0x20); \
+  ymm11 = _mm256_permute2f128_pd(ymm10, ymm12, 0x31); \
+  ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); \
+  ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); \
+  ymm10 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); \
+  ymm12 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); \
+  \
+  ymm8 = _mm256_broadcast_sd((double const *)(&AlphaVal)); \
+  ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 4))); \
+  ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 5))); \
+  ymm2 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 6))); \
+  ymm3 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 7))); \
+  ymm0 = _mm256_fmsub_pd(ymm0, ymm8, ymm13); \
+  ymm1 = _mm256_fmsub_pd(ymm1, ymm8, ymm14); \
+  ymm2 = _mm256_fmsub_pd(ymm2, ymm8, ymm15); \
+  ymm3 = _mm256_fmsub_pd(ymm3, ymm8, ymm16); \
+  \
+  ymm14 = _mm256_unpacklo_pd(ymm0, ymm1); \
+  ymm16 = _mm256_unpacklo_pd(ymm2, ymm3); \
+  ymm13 = _mm256_permute2f128_pd(ymm14, ymm16, 0x20); \
+  ymm15 = _mm256_permute2f128_pd(ymm14, ymm16, 0x31); \
+  ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); \
+  ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); \
+  ymm14 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); \
+  ymm16 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8_AND_STORE(b11, cs_b) \
+  ymm1 = _mm256_unpacklo_pd(ymm9, ymm10);  \
+  ymm3 = _mm256_unpacklo_pd(ymm11, ymm12); \
+  \
+  /*rearrange low elements*/ \
+  ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); \
+  ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); \
+  \
+  /*unpack high*/ \
+  ymm9  = _mm256_unpackhi_pd(ymm9, ymm10);  \
+  ymm10 = _mm256_unpackhi_pd(ymm11, ymm12); \
+  \
+  /*rearrange high elements*/ \
+  ymm1 = _mm256_permute2f128_pd(ymm9, ymm10, 0x20); \
+  ymm3 = _mm256_permute2f128_pd(ymm9, ymm10, 0x31); \
+  \
+  /*unpacklow*/ \
+  ymm5 = _mm256_unpacklo_pd(ymm13, ymm14); \
+  ymm7 = _mm256_unpacklo_pd(ymm15, ymm16); \
+  \
+  /*rearrange low elements*/ \
+  ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); \
+  ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); \
+  \
+  /*unpack high*/ \
+  ymm13 = _mm256_unpackhi_pd(ymm13, ymm14); \
+  ymm14 = _mm256_unpackhi_pd(ymm15, ymm16); \
+  \
+  /*rearrange high elements*/ \
+  ymm5 = _mm256_permute2f128_pd(ymm13, ymm14, 0x20); \
+  ymm7 = _mm256_permute2f128_pd(ymm13, ymm14, 0x31);
+
+/*
+zmm9 [0] zmm9 [1] zmm9 [2] zmm9 [3] zmm9 [4] zmm9 [5] zmm9 [6] zmm9 [7]
+zmm10[0] zmm10[1] zmm10[2] zmm10[3] zmm10[4] zmm10[5] zmm10[6] zmm10[7]
+zmm11[0] zmm11[1] zmm11[2] zmm11[3] zmm11[4] zmm11[5] zmm11[6] zmm11[7]
+zmm12[0] zmm12[1] zmm12[2] zmm12[3] zmm12[4] zmm12[5] zmm12[6] zmm12[7]
+zmm13[0] zmm13[1] zmm13[2] zmm13[3] zmm13[4] zmm13[5] zmm13[6] zmm13[7]
+zmm14[0] zmm14[1] zmm14[2] zmm14[3] zmm14[4] zmm14[5] zmm14[6] zmm14[7]
+zmm15[0] zmm15[1] zmm15[2] zmm15[3] zmm15[4] zmm15[5] zmm15[6] zmm15[7]
+zmm16[0] zmm16[1] zmm16[2] zmm16[3] zmm16[4] zmm16[5] zmm16[6] zmm16[7]
+
+Stage1
+zmm17 = zmm10[1] zmm9 [1] zmm10[3] zmm9 [3] zmm10[5] zmm9 [5] zmm10[7] zmm9 [7]
+zmm18 = zmm12[1] zmm11[1] zmm12[3] zmm11[3] zmm12[5] zmm11[5] zmm12[7] zmm11[7]
+zmm19 = zmm14[1] zmm13[1] zmm14[3] zmm13[3] zmm14[5] zmm13[5] zmm14[7] zmm13[7]
+zmm20 = zmm16[1] zmm15[1] zmm16[3] zmm15[3] zmm16[5] zmm15[5] zmm16[7] zmm15[7]
+
+Stage2
+zmm21 = zmm12[3] zmm11[3] zmm12[7] zmm11[7] zmm10[3] zmm9 [3] zmm10[7] zmm9 [7]
+zmm22 = zmm16[3] zmm15[3] zmm16[7] zmm15[7] zmm14[3] zmm13[3] zmm14[7] zmm13[7]
+
+Stage3 1,5
+zmm0 = zmm16[7] zmm15[7] zmm14[7] zmm13[7] zmm12[7] zmm11[7] zmm10[7] zmm9 [7]
+zmm4 = zmm16[3] zmm15[3] zmm14[3] zmm13[3] zmm12[3] zmm11[3] zmm10[3] zmm9 [3]
+
+Stage2
+zmm21 = zmm12[1] zmm11[1] zmm12[5] zmm11[5] zmm10[1] zmm9 [1] zmm10[5] zmm9 [5]
+zmm22 = zmm16[1] zmm15[1] zmm16[5] zmm15[5] zmm14[1] zmm13[1] zmm14[5] zmm13[5]
+
+Stage3 3,7
+zmm2 = zmm16[5] zmm15[5] zmm14[5] zmm13[5] zmm12[5] zmm11[5] zmm10[5] zmm9 [5]
+zmm6 = zmm16[1] zmm15[1] zmm14[1] zmm13[1] zmm12[1] zmm11[1] zmm10[1] zmm9 [1]
+
+Stage1
+zmm17 = zmm10[0] zmm9 [0] zmm10[2] zmm9 [2] zmm10[4] zmm9 [4] zmm10[6] zmm9 [6]
+zmm18 = zmm12[0] zmm11[0] zmm12[2] zmm11[2] zmm12[4] zmm11[4] zmm12[6] zmm11[6]
+zmm19 = zmm14[0] zmm13[0] zmm14[2] zmm13[2] zmm14[4] zmm13[4] zmm14[6] zmm13[6]
+zmm20 = zmm16[0] zmm15[0] zmm16[2] zmm15[2] zmm16[4] zmm15[4] zmm16[6] zmm15[6]
+
+Stage2
+zmm21 = zmm12[2] zmm11[2] zmm12[6] zmm11[6] zmm10[2] zmm9 [2] zmm10[6] zmm9 [6]
+zmm22 = zmm16[2] zmm15[2] zmm16[6] zmm15[6] zmm14[2] zmm13[2] zmm14[6] zmm13[6]
+
+Stage3 2,6
+zmm1 = zmm16[6] zmm15[6] zmm14[6] zmm13[6] zmm12[6] zmm11[6] zmm10[6] zmm9 [6]
+zmm5 = zmm16[2] zmm15[2] zmm14[2] zmm13[2] zmm12[2] zmm11[2] zmm10[2] zmm9 [2]
+
+Stage2
+zmm21 = zmm12[0] zmm11[0] zmm12[4] zmm11[4] zmm10[0] zmm9 [0] zmm10[4] zmm9 [4]
+zmm22 = zmm16[0] zmm15[0] zmm16[4] zmm15[4] zmm14[0] zmm13[0] zmm14[4] zmm13[4]
+
+Stage3 4,8
+zmm3 = zmm16[4] zmm15[4] zmm14[4] zmm13[4] zmm12[4] zmm11[4] zmm10[4] zmm9 [4]
+zmm7 = zmm16[0] zmm15[0] zmm14[0] zmm13[0] zmm12[0] zmm11[0] zmm10[0] zmm9 [0]
+*/
+#define BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8_AND_STORE(b11, cs_b) \
+  /*Stage1*/ \
+  zmm17 = _mm512_unpacklo_pd(zmm9, zmm10);  \
+  zmm18 = _mm512_unpacklo_pd(zmm11, zmm12); \
+  zmm19 = _mm512_unpacklo_pd(zmm13, zmm14); \
+  zmm20 = _mm512_unpacklo_pd(zmm15, zmm16); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \
+  /*Stage3 1,5*/ \
+  zmm0 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm4 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \
+  /*Stage3 3,7*/ \
+  zmm2 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm6 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \
+  /*Stage1*/ \
+  zmm17 = _mm512_unpackhi_pd(zmm9, zmm10);  \
+  zmm18 = _mm512_unpackhi_pd(zmm11, zmm12); \
+  zmm19 = _mm512_unpackhi_pd(zmm13, zmm14); \
+  zmm20 = _mm512_unpackhi_pd(zmm15, zmm16); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b10001000); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b10001000); \
+  /*Stage3 2,6*/ \
+  zmm1 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm5 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101); \
+  /*Stage2*/ \
+  zmm21 = _mm512_shuffle_f64x2(zmm17, zmm18, 0b11011101); \
+  zmm22 = _mm512_shuffle_f64x2(zmm19, zmm20, 0b11011101); \
+  /*Stage3 4,8*/ \
+  zmm3 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b10001000); \
+  zmm7 = _mm512_shuffle_f64x2(zmm21, zmm22, 0b11011101);
+
+// endregion - 8x8 transpose for left variants
+
+// region - GEMM DTRSM for left variants
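+// In the 8mx8n GEMM below, the k loop is split into two halves with
+// independent accumulator banks (zmm9..zmm16 and zmm24..zmm31) so that
+// consecutive FMAs do not stall on each other; the banks are summed once
+// at the end, overlapped with prefetches of the b11 tile.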
+#define BLIS_DTRSM_SMALL_GEMM_8mx8n_AVX512(a10, b01, cs_b, p_lda, k_iter, b11) \
+  int itrCount = (k_iter / 2); \
+  int itr  = itrCount; \
+  int itr2 = k_iter - itrCount; \
+  double *b01_2 = b01 + itr; \
+  double *a10_2 = a10 + (p_lda * itr); \
+  for (; itr > 0; itr--) \
+  { \
+    zmm0 = _mm512_loadu_pd((double const *)a10); \
+    \
+    zmm1 = _mm512_set1_pd(*(b01 + cs_b * 0)); \
+    zmm2 = _mm512_set1_pd(*(b01 + cs_b * 1)); \
+    zmm3 = _mm512_set1_pd(*(b01 + cs_b * 2)); \
+    zmm4 = _mm512_set1_pd(*(b01 + cs_b * 3)); \
+    zmm5 = _mm512_set1_pd(*(b01 + cs_b * 4)); \
+    zmm6 = _mm512_set1_pd(*(b01 + cs_b * 5)); \
+    zmm7 = _mm512_set1_pd(*(b01 + cs_b * 6)); \
+    zmm8 = _mm512_set1_pd(*(b01 + cs_b * 7)); \
+    \
+    _mm_prefetch((b01 + 8), _MM_HINT_T0); \
+    zmm9  = _mm512_fmadd_pd(zmm1, zmm0, zmm9);  \
+    zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \
+    zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \
+    zmm12 = _mm512_fmadd_pd(zmm4, zmm0, zmm12); \
+    zmm13 = _mm512_fmadd_pd(zmm5, zmm0, zmm13); \
+    zmm14 = _mm512_fmadd_pd(zmm6, zmm0, zmm14); \
+    zmm15 = _mm512_fmadd_pd(zmm7, zmm0, zmm15); \
+    zmm16 = _mm512_fmadd_pd(zmm8, zmm0, zmm16); \
+    \
+    b01 += 1; \
+    a10 += p_lda; \
+  } \
+  for (; itr2 > 0; itr2--) \
+  { \
+    zmm23 = _mm512_loadu_pd((double const *)a10_2); \
+    \
+    zmm17 = _mm512_set1_pd(*(b01_2 + cs_b * 0)); \
+    zmm18 = _mm512_set1_pd(*(b01_2 + cs_b * 1)); \
+    zmm19 = _mm512_set1_pd(*(b01_2 + cs_b * 2)); \
+    zmm20 = _mm512_set1_pd(*(b01_2 + cs_b * 3)); \
+    zmm21 = _mm512_set1_pd(*(b01_2 + cs_b * 4)); \
+    zmm22 = _mm512_set1_pd(*(b01_2 + cs_b * 5)); \
+    \
+    _mm_prefetch((b01_2 + 8), _MM_HINT_T0); \
+    zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \
+    zmm17 = _mm512_set1_pd(*(b01_2 + cs_b * 6)); \
+    zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \
+    zmm18 = _mm512_set1_pd(*(b01_2 + cs_b * 7)); \
+    zmm26 = _mm512_fmadd_pd(zmm19, zmm23, zmm26); \
+    zmm27 = _mm512_fmadd_pd(zmm20, zmm23, zmm27); \
+    zmm28 = _mm512_fmadd_pd(zmm21, zmm23, zmm28); \
+    zmm29 = _mm512_fmadd_pd(zmm22, zmm23, zmm29); \
+    zmm30 = _mm512_fmadd_pd(zmm17, zmm23, zmm30); \
+    zmm31 = _mm512_fmadd_pd(zmm18, zmm23, zmm31); \
+    \
+    b01_2 += 1; \
+    a10_2 += p_lda; \
+  } \
+  _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \
+  zmm9 = _mm512_add_pd(zmm9, zmm24); \
+  _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \
+  zmm10 = _mm512_add_pd(zmm10, zmm25); \
+  _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \
+  zmm11 = _mm512_add_pd(zmm11, zmm26); \
+  _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \
+  zmm12 = _mm512_add_pd(zmm12, zmm27); \
+  _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \
+  zmm13 = _mm512_add_pd(zmm13, zmm28); \
+  _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \
+  zmm14 = _mm512_add_pd(zmm14, zmm29); \
+  _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \
+  zmm15 = _mm512_add_pd(zmm15, zmm30); \
+  _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \
+  zmm16 = _mm512_add_pd(zmm16, zmm31);
+
+#define BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10));     \
+    ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01)); \
+    ymm8  = _mm256_fmadd_pd(ymm2, ymm0, ymm8);  \
+    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm9  = _mm256_fmadd_pd(ymm2, ymm0, ymm9);  \
+    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \
+    ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \
+    ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 3))); \
+    ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); \
+    ymm15 = _mm256_fmadd_pd(ymm2, ymm1, ymm15); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_8mx3n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10));     \
+    ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8  = _mm256_fmadd_pd(ymm2, ymm0, ymm8);  \
+    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm9  = _mm256_fmadd_pd(ymm2, ymm0, ymm9);  \
+    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \
+    ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \
+    ymm14 = _mm256_fmadd_pd(ymm2, ymm1, ymm14); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_8mx2n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10));     \
+    ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8  = _mm256_fmadd_pd(ymm2, ymm0, ymm8);  \
+    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm9  = _mm256_fmadd_pd(ymm2, ymm0, ymm9);  \
+    ymm13 = _mm256_fmadd_pd(ymm2, ymm1, ymm13); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_8mx1n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10));     \
+    ymm1 = _mm256_loadu_pd((double const *)(a10 + 4)); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8  = _mm256_fmadd_pd(ymm2, ymm0, ymm8);  \
+    ymm12 = _mm256_fmadd_pd(ymm2, ymm1, ymm12); \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10)); \
+    \
+    ymm1 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm3 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \
+    ymm4 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 3))); \
+    ymm5 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 4))); \
+    ymm6 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 5))); \
+    ymm7 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 6))); \
+    ymm8 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 7))); \
+    \
+    _mm_prefetch((b01 + 4 * cs_b), _MM_HINT_T0); \
+    ymm9  = _mm256_fmadd_pd(ymm1, ymm0, ymm9);  \
+    ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \
+    ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \
+    ymm12 = _mm256_fmadd_pd(ymm4, ymm0, ymm12); \
+    ymm13 = _mm256_fmadd_pd(ymm5, ymm0, ymm13); \
+    ymm14 = _mm256_fmadd_pd(ymm6, ymm0, ymm14); \
+    ymm15 = _mm256_fmadd_pd(ymm7, ymm0, ymm15); \
+    ymm16 = _mm256_fmadd_pd(ymm8, ymm0, ymm16); \
+    \
+    b01 += 1; \
+    a10 += p_lda; \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10)); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \
+    ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 3))); \
+    ymm11 = _mm256_fmadd_pd(ymm2, ymm0, ymm11); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10)); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \
+    \
+    ymm2  = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 2))); \
+    ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10)); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 1))); \
+    ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+#define BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) \
+  for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \
+  { \
+    ymm0 = _mm256_loadu_pd((double const *)(a10)); \
+    \
+    ymm2 = _mm256_broadcast_sd((double const *)(b01 + (cs_b * 0))); \
+    ymm8 = _mm256_fmadd_pd(ymm2, ymm0, ymm8); \
+    \
+    b01 += 1;     /*move to next row of B*/ \
+    a10 += p_lda; /*pointer math to calculate next block of A for GEMM*/ \
+  }
+
+// endregion - GEMM DTRSM for left variants
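+// The PRE_*M_*N helpers in the next region apply alpha and subtract the
+// GEMM accumulators on partial tiles without touching memory past the m
+// fringe: for 3 rows, a 128-bit load plus a broadcast of element 2
+// assembles a column, and the store mirrors that split (_mm_storeu_pd for
+// rows 0..1, _mm_storel_pd for row 2).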
+// region - pre/post DTRSM for left variants
+
+#define BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); \
+  ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0) + 2)); \
+  ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); \
+  ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1) + 2)); \
+  ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); \
+  ymm2 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2) + 2)); \
+  ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); \
+  \
+  ymm8  = _mm256_fmsub_pd(ymm0, ymm16, ymm8);  \
+  ymm9  = _mm256_fmsub_pd(ymm1, ymm16, ymm9);  \
+  ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); \
+  \
+  xmm5 = _mm256_castpd256_pd128(ymm8); \
+  _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); \
+  _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1)); \
+  xmm5 = _mm256_castpd256_pd128(ymm9); \
+  _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); \
+  _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); \
+  xmm5 = _mm256_castpd256_pd128(ymm10); \
+  _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); \
+  _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1));
+
+#define BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); \
+  ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0) + 2)); \
+  ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); \
+  ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1) + 2)); \
+  ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \
+  \
+  ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \
+  ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \
+  \
+  xmm5 = _mm256_castpd256_pd128(ymm8); \
+  _mm_storeu_pd((double *)(b11 + cs_b * 0), xmm5); \
+  _mm_storel_pd((b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm8, 1)); \
+  xmm5 = _mm256_castpd256_pd128(ymm9); \
+  _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); \
+  _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1));
+
+#define BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); \
+  ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0) + 2)); \
+  ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \
+  ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \
+  \
+  xmm5 = _mm256_castpd256_pd128(ymm8); \
+  _mm_storeu_pd((double *)(b11), xmm5); \
+  _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm8, 1));
+
+#define BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 0))); \
+  ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 1))); \
+  ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 2))); \
+  ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); \
+  \
+  ymm8  = _mm256_fmsub_pd(ymm0, ymm16, ymm8);  \
+  ymm9  = _mm256_fmsub_pd(ymm1, ymm16, ymm9);  \
+  ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); \
+  \
+  xmm5 = _mm256_castpd256_pd128(ymm8); \
+  _mm_storeu_pd((double *)(b11 + (cs_b * 0)), xmm5); \
+  xmm5 = _mm256_castpd256_pd128(ymm9); \
+  _mm_storeu_pd((double *)(b11 + (cs_b * 1)), xmm5); \
+  xmm5 = _mm256_castpd256_pd128(ymm10); \
+  _mm_storeu_pd((double *)(b11 + (cs_b * 2)), xmm5);
+
+#define BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 0))); \
+  ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 1))); \
+  ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); \
+  \
+  ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \
+  ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \
+  \
+  xmm5 = _mm256_castpd256_pd128(ymm8); \
+  _mm_storeu_pd((double *)(b11 + (cs_b * 0)), xmm5); \
+  xmm5 = _mm256_castpd256_pd128(ymm9); \
+  _mm_storeu_pd((double *)(b11 + (cs_b * 1)), xmm5);
+
+#define BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 0))); \
+  ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \
+  \
+  ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \
+  \
+  xmm5 = _mm256_castpd256_pd128(ymm8); \
+  _mm_storeu_pd((double *)(b11 + (cs_b * 0)), xmm5);
+
+#define BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0))); \
+  ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1))); \
+  ymm2 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2))); \
+  \
+  ymm8  = _mm256_fmsub_pd(ymm0, ymm16, ymm8);  \
+  ymm9  = _mm256_fmsub_pd(ymm1, ymm16, ymm9);  \
+  ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); \
+  \
+  _mm_storel_pd((double *)(b11),            _mm256_extractf128_pd(ymm8, 0));  \
+  _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0));  \
+  _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10, 0));
+
+#define BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 0))); \
+  ymm1 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 1))); \
+  \
+  ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \
+  ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); \
+  \
+  _mm_storel_pd((double *)(b11),            _mm256_extractf128_pd(ymm8, 0)); \
+  _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0));
+
+#define BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal, b11, cs_b) \
+  ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); /*register to hold alpha*/ \
+  \
+  ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0)); \
+  ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); \
+  \
+  _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0));
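+// Left-side solve, LLNN (A lower, no transpose) and LUTN (A upper,
+// transpose): X = alpha * inv(A) * B by forward substitution on 8-row
+// panels,
+//   X[k, :] = ( alpha*B[k, :] - sum_{p < k} A[k, p] * X[p, :] ) / A[k, k]
+// The GEMM macros above accumulate the sum for a panel; the per-panel TRSM
+// below finishes the elimination and the divide (or multiply, with
+// preinversion) in registers.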
 // LLNN - LUTN
 BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512
 (
@@ -6508,9 +7227,1956 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512
   cntl_t* cntl
 )
 {
-  return BLIS_NOT_YET_IMPLEMENTED;
+  dim_t m = bli_obj_length(b); // number of rows
+  dim_t n = bli_obj_width(b);  // number of columns
+  bool transa = bli_obj_has_trans(a);
+  dim_t cs_a, rs_a;
+  dim_t d_mr = 8, d_nr = 8;
+
+  // Swap rs_a & cs_a in case of non-transpose.
+  if (transa)
+  {
+    cs_a = bli_obj_col_stride(a); // column stride of A
+    rs_a = bli_obj_row_stride(a); // row stride of A
+  }
+  else
+  {
+    cs_a = bli_obj_row_stride(a); // row stride of A
+    rs_a = bli_obj_col_stride(a); // column stride of A
+  }
+  dim_t cs_b = bli_obj_col_stride(b); // column stride of B
+
+  dim_t i, j, k;
+  dim_t k_iter;
+
+  double AlphaVal = *(double *)AlphaObj->buffer;
+  double *L = bli_obj_buffer_at_off(a); // pointer to matrix A
+  double *B = bli_obj_buffer_at_off(b); // pointer to matrix B
+
+  double *a10, *a11, *b01, *b11; // pointers for GEMM and TRSM blocks
+
+  double ones = 1.0;
+
+  const gint_t required_packing_A = 1;
+  mem_t local_mem_buf_A_s = {0};
+  double *D_A_pack = NULL; // pointer to packed A buffer
+  double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack
+  rntm_t rntm;
+
+  bli_rntm_init_from_global(&rntm);
+  bli_rntm_set_num_threads_only(1, &rntm);
+  bli_membrk_rntm_set_membrk(&rntm);
+
+  siz_t buffer_size = bli_pool_block_size(
+      bli_membrk_pool(
+          bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+          bli_rntm_membrk(&rntm)));
+
+  if ((d_mr * m * sizeof(double)) > buffer_size)
+    return BLIS_NOT_YET_IMPLEMENTED;
+
+  if (required_packing_A == 1)
+  {
+    // Get the buffer from the pool.
+    bli_membrk_acquire_m(&rntm,
+                         buffer_size,
+                         BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                         &local_mem_buf_A_s);
+    if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s))
+      return BLIS_NULL_POINTER;
+    D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+    if (NULL == D_A_pack)
+      return BLIS_NULL_POINTER;
+  }
+  bool is_unitdiag = bli_obj_has_unit_diag(a);
+  __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11;
+  __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21;
+  __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
+  __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11;
+  __m256d ymm12, ymm13, ymm14, ymm15, ymm16;
+  __m128d xmm5;
+  xmm5 = _mm_setzero_pd();
+
+  /*
+    Performs solving TRSM for 8 rows at a time, from 0 to m/8 in steps of d_mr
+    a. Load, transpose and pack A (a10 block); the size of packing grows from
+       8x8 to 8x(m-8). For the first iteration there is no GEMM and no packing
+       of a10 because it is only TRSM
+    b. Using the packed a10 block and the b01 block, perform the GEMM operation
+    c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B
+    d. Repeat b,c for n columns of B in steps of d_nr
+  */
+  for (i = 0; (i + d_mr - 1) < m; i += d_mr)
+  {
+    a10 = L + (i * cs_a);
+    a11 = L + (i * rs_a) + (i * cs_a);
+
+    dim_t p_lda = d_mr;
+
+    if (transa)
+    {
+      /*
+        Pack current A block (a10) into packed buffer memory D_A_pack
+        a. This a10 block is used in the GEMM portion only, and its size
+           increases by d_mr every iteration until it reaches 8x(m-8),
+           which is the maximum GEMM-only block size in A
+        b. This packed buffer is reused to calculate all n columns of B
+      */
+      bli_dtrsm_small_pack_avx512('L', i, 1, a10, cs_a, D_A_pack, p_lda, d_mr);
+      /*
+        Pack 8 diagonal elements of the A block into an array
+        a. This helps to utilize the cache line efficiently in the TRSM operation
+        b. Store ones when the input is unit-diagonal
+      */
+      dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, cs_a, d11_pack, d_mr);
+    }
+    else
+    {
+      bli_dtrsm_small_pack_avx512('L', i, 0, a10, rs_a, D_A_pack, p_lda, d_mr);
+      dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, rs_a, d11_pack, d_mr);
+    }
+
+    /*
+      a. Perform GEMM using a10, b01.
+      b. Perform TRSM on a11, b11
+      c. This GEMM+TRSM loop operates on 8x8 blocks along the n dimension,
+         for every d_nr columns of B01; the packed A buffer is reused in
+         computing all columns of B.
+      d. The same approach is used in the remaining fringe cases.
+    */
+    for (j = 0; j < n - d_nr + 1; j += d_nr)
+    {
+      a10 = D_A_pack;
+      a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM
+      b01 = B + j * cs_b;                // pointer to block of B to be used in GEMM
+      b11 = B + i + j * cs_b;            // pointer to block of B to be used for TRSM
+      k_iter = i;
+
+      BLIS_SET_ZMM_REG_ZEROS
+      /*
+        Perform GEMM between a10 and b01 blocks
+        For the first iteration there is no GEMM operation,
+        where k_iter is zero
+      */
+      BLIS_DTRSM_SMALL_GEMM_8mx8n_AVX512(a10, b01, cs_b, p_lda, k_iter, b11)
+      /*
+        Load b11 of size 8x8 and multiply with alpha
+        Add the GEMM output and perform an in-register transpose of b11
+        to perform the TRSM operation.
+      */
+      BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8(b11, cs_b, AlphaVal)
+
+      /*
+        Compute the 8x8 TRSM block by using the GEMM block output in registers
+        a. The 8x8 inputs (GEMM outputs) are stored in combinations of zmm registers
+           row      :  0     1     2     3     4     5     6     7
+           register : zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 zmm16
+        b. Towards the end the TRSM output will be stored back into b11
+      */
+      // extract a00
+      zmm0 = _mm512_set1_pd(*(d11_pack + 0));
+      zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0);
+
+      // extract a11
+      zmm1 = _mm512_set1_pd(*(d11_pack + 1));
+      zmm2 = _mm512_set1_pd(*(a11 + (1 * cs_a)));
+      zmm10 = _mm512_fnmadd_pd(zmm2, zmm9, zmm10);
+      zmm3 = _mm512_set1_pd(*(a11 + (2 * cs_a)));
+      zmm11 = _mm512_fnmadd_pd(zmm3, zmm9, zmm11);
+      zmm4 = _mm512_set1_pd(*(a11 + (3 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm4, zmm9, zmm12);
+      zmm5 = _mm512_set1_pd(*(a11 + (4 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm5, zmm9, zmm13);
+      zmm6 = _mm512_set1_pd(*(a11 + (5 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm6, zmm9, zmm14);
+      zmm7 = _mm512_set1_pd(*(a11 + (6 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm7, zmm9, zmm15);
+      zmm8 = _mm512_set1_pd(*(a11 + (7 * cs_a)));
+      zmm16 = _mm512_fnmadd_pd(zmm8, zmm9, zmm16);
+      zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1);
+      a11 += rs_a;
+
+      // extract a22
+      zmm0 = _mm512_set1_pd(*(d11_pack + 2));
+      zmm2 = _mm512_set1_pd(*(a11 + (2 * cs_a)));
+      zmm11 = _mm512_fnmadd_pd(zmm2, zmm10, zmm11);
+      zmm3 = _mm512_set1_pd(*(a11 + (3 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm3, zmm10, zmm12);
+      zmm4 = _mm512_set1_pd(*(a11 + (4 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm4, zmm10, zmm13);
+      zmm5 = _mm512_set1_pd(*(a11 + (5 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm5, zmm10, zmm14);
+      zmm6 = _mm512_set1_pd(*(a11 + (6 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm6, zmm10, zmm15);
+      zmm7 = _mm512_set1_pd(*(a11 + (7 * cs_a)));
+      zmm16 = _mm512_fnmadd_pd(zmm7, zmm10, zmm16);
+      zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm0);
+      a11 += rs_a;
+
+      // extract a33
+      zmm1 = _mm512_set1_pd(*(d11_pack + 3));
+      zmm2 = _mm512_set1_pd(*(a11 + (3 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm2, zmm11, zmm12);
+      zmm3 = _mm512_set1_pd(*(a11 + (4 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm3, zmm11, zmm13);
+      zmm4 = _mm512_set1_pd(*(a11 + (5 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm4, zmm11, zmm14);
+      zmm5 = _mm512_set1_pd(*(a11 + (6 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm5, zmm11, zmm15);
+      zmm6 = _mm512_set1_pd(*(a11 + (7 * cs_a)));
+      zmm16 = _mm512_fnmadd_pd(zmm6, zmm11, zmm16);
+      zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1);
+      a11 += rs_a;
+
+      // extract a44
+      zmm0 = _mm512_set1_pd(*(d11_pack + 4));
+      zmm2 = _mm512_set1_pd(*(a11 + (4 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm2, zmm12, zmm13);
+      zmm3 = _mm512_set1_pd(*(a11 + (5 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm3, zmm12, zmm14);
+      zmm4 = _mm512_set1_pd(*(a11 + (6 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm4, zmm12, zmm15);
+      zmm5 = _mm512_set1_pd(*(a11 + (7 * cs_a)));
+      zmm16 = _mm512_fnmadd_pd(zmm5, zmm12, zmm16);
+      zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm0);
+      a11 += rs_a;
+
+      // extract a55
+      zmm1 = _mm512_set1_pd(*(d11_pack + 5));
+      zmm2 = _mm512_set1_pd(*(a11 + (5 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm2, zmm13, zmm14);
+      zmm3 = _mm512_set1_pd(*(a11 + (6 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm3, zmm13, zmm15);
+      zmm4 = _mm512_set1_pd(*(a11 + (7 * cs_a)));
+      zmm16 = _mm512_fnmadd_pd(zmm4, zmm13, zmm16);
+      zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1);
+      a11 += rs_a;
+
+      // extract a66
+      zmm0 = _mm512_set1_pd(*(d11_pack + 6));
+      zmm2 = _mm512_set1_pd(*(a11 + (6 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm2, zmm14, zmm15);
+      zmm3 = _mm512_set1_pd(*(a11 + (7 * cs_a)));
+      zmm16 = _mm512_fnmadd_pd(zmm3, zmm14, zmm16);
+      zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm0);
+      a11 += rs_a;
+
+      // extract a77
+      zmm1 = _mm512_set1_pd(*(d11_pack + 7));
+      zmm2 = _mm512_set1_pd(*(a11 + 7 * cs_a));
+      zmm16 = _mm512_fnmadd_pd(zmm2, zmm15, zmm16);
+      zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm1);
+
+      BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8_AND_STORE(b11, cs_b)
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 0)), zmm0);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 1)), zmm1);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 2)), zmm2);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 3)), zmm3);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 4)), zmm4);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 5)), zmm5);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 6)), zmm6);
+      _mm512_storeu_pd((double *)(b11 + (cs_b * 7)), zmm7);
+    }
+    dim_t n_rem = n - j;
+    if (n_rem >= 4)
+    {
+      a10 = D_A_pack;
+      a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM
+      b01 = B + (j * cs_b);              // pointer to block of B to be used for GEMM
+      b11 = B + i + (j * cs_b);          // pointer to block of B to be used for TRSM
+
+      k_iter = i; // number of times GEMM is to be performed
+
+      /* Fill zeros into ymm registers used in gemm accumulations */
+      BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+      /// GEMM code begins ///
+      BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter)
+
+      ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+      ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 0)));
+      // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+      ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 1)));
+      // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+      ymm2 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2)));
+      // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+      ymm3 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3)));
+      // B11[0][3] B11[1][3] B11[2][3] B11[3][3]
+      ymm4 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 0) + 4));
+      // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+      ymm5 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 1) + 4));
+      // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+      ymm6 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2) + 4));
+      // B11[4][2] B11[5][2] B11[6][2] B11[7][2]
+      ymm7 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3) + 4));
+      // B11[4][3] B11[5][3] B11[6][3] B11[7][3]
+
+      ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);  // B11[0-3][0] * alpha -= B01[0-3][0]
+      ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);  // B11[0-3][1] * alpha -= B01[0-3][1]
+      ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2]
+      ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); // B11[0-3][3] * alpha -= B01[0-3][3]
+      ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[4-7][0] * alpha -= B01[4-7][0]
+      ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[4-7][1] * alpha -= B01[4-7][1]
+      ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); // B11[4-7][2] * alpha -= B01[4-7][2]
+      ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); // B11[4-7][3] * alpha -= B01[4-7][3]
+
+      /// implement TRSM ///
+
+      /// transpose of B11 ///
+      /// unpacklow ///
+      ymm9  = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1]
+      ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3]
+
+      ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[4][0] B11[4][1] B11[6][0] B11[6][1]
+      ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[4][2] B11[4][3] B11[6][2] B11[6][3]
+
+      // rearrange low elements
+      ymm8  = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3]
+      ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+
+      ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3]
+      ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3]
+
+      //// unpackhigh ////
+      ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1]
+      ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3]
+
+      ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[5][0] B11[5][1] B11[7][0] B11[7][1]
+      ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[5][2] B11[5][3] B11[7][2] B11[7][3]
+
+      // rearrange high elements
+      ymm9  = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3]
+      ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3]
+
+      ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3]
+      ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3]
+
+      ymm0 = _mm256_broadcast_sd((double const *)&ones);
+
+      // extract a00
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+      // perform mul operation
+      ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1);
+
+      // extract a11
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+      ymm2  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1)));
+      ymm3  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2)));
+      ymm4  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3)));
+      ymm5  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4)));
+      ymm6  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5)));
+      ymm7  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6)));
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7)));
+
+      a11 += rs_a;
+
+      // (Row 1): FMA operations
+      ymm9  = _mm256_fnmadd_pd(ymm2, ymm8, ymm9);
+      ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10);
+      ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11);
+      ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12);
+      ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13);
+      ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14);
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15);
+
+      // perform mul operation
+      ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1);
+
+      ymm3  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2)));
+      ymm4  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3)));
+      ymm5  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4)));
+      ymm6  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5)));
+      ymm7  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6)));
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7)));
+
+      a11 += rs_a;
+
+      // extract a22
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+      // (Row 2): FMA operations
+      ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10);
+      ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11);
+      ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12);
+      ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13);
+      ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14);
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15);
+
+      // perform mul operation
+      ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1);
+
+      ymm4  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3)));
+      ymm5  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4)));
+      ymm6  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5)));
+      ymm7  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6)));
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7)));
+
+      a11 += rs_a;
+
+      // extract a33
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3));
+
+      // (Row 3): FMA operations
+      ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11);
+      ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12);
+      ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13);
+      ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14);
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15);
+
+      // perform mul operation
+      ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1);
+
+      ymm0 = _mm256_broadcast_sd((double const *)&ones);
+
+      // extract a44
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4));
+
+      ymm5  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4)));
+      ymm6  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5)));
+      ymm7  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6)));
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7)));
+
+      a11 += rs_a;
+
+      // (Row 4): FMA operations
+      ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12);
+      ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13);
+      ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14);
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15);
+
+      // perform mul operation
+      ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1);
+
+      ymm6  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5)));
+      ymm7  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6)));
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7)));
+
+      a11 += rs_a;
+
+      // extract a55
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5));
+
+      // (Row 5): FMA operations
+      ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13);
+      ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14);
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15);
+
+      // perform mul operation
+      ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1);
+
+      ymm7  = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6)));
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7)));
+
+      a11 += rs_a;
+
+      // extract a66
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6));
+
+      // (Row 6): FMA operations
+      ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14);
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15);
+
+      // perform mul operation
+      ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1);
+
+      // extract a77
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7));
+
+      ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 7));
+
+      a11 += rs_a;
+      // (Row 7): FMA operations
+      ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15);
+
+      // perform mul operation
+      ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1);
+
+      // unpacklow //
+      ymm1 = _mm256_unpacklo_pd(ymm8, ymm9);
+      // B11[0][0] B11[1][0] B11[0][2] B11[1][2]
+      ymm3 = _mm256_unpacklo_pd(ymm10, ymm11);
+      // B11[2][0] B11[3][0] B11[2][2] B11[3][2]
+
+      ymm5 = _mm256_unpacklo_pd(ymm12, ymm13);
+      // B11[4][0] B11[5][0] B11[4][2] B11[5][2]
+      ymm7 = _mm256_unpacklo_pd(ymm14, ymm15);
+      // B11[6][0] B11[7][0] B11[6][2] B11[7][2]
+
+      // rearrange low elements
+      ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20);
+      // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+      ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31);
+      // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+
+      ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20);
+      // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+      ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31);
+      // B11[4][2] B11[5][2] B11[6][2] B11[7][2]
+
+      /// unpack high ///
+      ymm8 = _mm256_unpackhi_pd(ymm8, ymm9);
+      // B11[0][1] B11[1][1] B11[0][3] B11[1][3]
+      ymm9 = _mm256_unpackhi_pd(ymm10, ymm11);
+      // B11[2][1] B11[3][1] B11[2][3] B11[3][3]
+
+      ymm12 = _mm256_unpackhi_pd(ymm12, ymm13);
+      // B11[4][1] B11[5][1] B11[4][3] B11[5][3]
+      ymm13 = _mm256_unpackhi_pd(ymm14, ymm15);
+      // B11[6][1] B11[7][1] B11[6][3] B11[7][3]
+
+      // rearrange high elements
+      ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20);
+      // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+      ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31);
+      // B11[0][3] B11[1][3] B11[2][3] B11[3][3]
+
+      ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20);
+      // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+      ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31);
+      // B11[4][3] B11[5][3] B11[6][3] B11[7][3]
+
+      _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0);     // store B11[0][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1);     // store B11[1][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2);     // store B11[2][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3);     // store B11[3][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3]
+      _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); // store B11[7][0-3]
+
+      n_rem -= 4;
+      j += 4;
+    }
+
+    if (n_rem)
+    {
+      a10 = D_A_pack;
+      a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM
+      b01 = B + j * cs_b;                // pointer to block of B to be used for GEMM
+      b11 = B + i + j * cs_b;            // pointer to block of B to be used for TRSM
+
+      k_iter = i; // number of times GEMM is to be performed
+
+      /* Fill zeros into ymm registers used in gemm accumulations */
+      BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+      if (3 == n_rem)
+      {
+        /// GEMM code begins ///
+        BLIS_DTRSM_SMALL_GEMM_8mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+        ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+        // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+        ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));
+        // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+        ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1));
+        // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+        ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));
+
+        // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+        ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4));
+        // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+        ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4));
+        // B11[4][2] B11[5][2] B11[6][2] B11[7][2]
+        ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4));
+
+        ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+        // B11[0-3][0] * alpha -= B01[0-3][0]
+        ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+        // B11[0-3][1] * alpha -= B01[0-3][1]
+        ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+        // B11[0-3][2] * alpha -= B01[0-3][2]
+        ymm3 = _mm256_broadcast_sd((double const *)(&ones));
+
+        ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12);
+        // B11[4-7][0] * alpha -= B01[4-7][0]
+        ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13);
+        // B11[4-7][1] * alpha -= B01[4-7][1]
+        ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14);
+        // B11[4-7][2] * alpha -= B01[4-7][2]
+        ymm7 = _mm256_broadcast_sd((double const *)(&ones));
+      }
+      else if (2 == n_rem)
+      {
+        /// GEMM code begins ///
+        BLIS_DTRSM_SMALL_GEMM_8mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+        ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+        // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+        ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));
+        // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+        ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1));
+
+        // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+        ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4));
+        // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+        ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4));
+
+        ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0]
+        ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1]
+        ymm2 = _mm256_broadcast_sd((double const *)(&ones));
+        ymm3 = _mm256_broadcast_sd((double const *)(&ones));
+
+        ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[4-7][0] * alpha -= B01[4-7][0]
+        ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[4-7][1] * alpha -= B01[4-7][1]
+        ymm6 = _mm256_broadcast_sd((double const *)(&ones));
+        ymm7 = _mm256_broadcast_sd((double const *)(&ones));
+      }
+      else if (1 == n_rem)
+      {
+        /// GEMM code begins ///
+        BLIS_DTRSM_SMALL_GEMM_8mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+        ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+        // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+        ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));
+
+        // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+        ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4));
+
+        ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0]
+        ymm1 = _mm256_broadcast_sd((double const *)(&ones));
+        ymm2 = _mm256_broadcast_sd((double const *)(&ones));
+        ymm3 = _mm256_broadcast_sd((double const *)(&ones));
+
+        ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[4-7][0] * alpha -= B01[4-7][0]
+        ymm5 = _mm256_broadcast_sd((double const *)(&ones));
+        ymm6 = _mm256_broadcast_sd((double const *)(&ones));
+        ymm7 = _mm256_broadcast_sd((double const *)(&ones));
+      }
+      /// implement TRSM ///
+
+      /// transpose of B11 ///
+      /// unpacklow ///
+      ymm9  = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1]
+      ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3]
+
+      ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[4][0] B11[4][1] B11[6][0] B11[6][1]
+      ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[4][2] B11[4][3] B11[6][2] B11[6][3]
+
+      // rearrange low elements
+      ymm8  = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3]
+      ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+
+      ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3]
+      ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3]
+
+      //// unpackhigh ////
+      ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1]
+      ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3]
+
+      ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[5][0] B11[5][1] B11[7][0] B11[7][1]
+      ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[5][2] B11[5][3] B11[7][2] B11[7][3]
+
+      // rearrange high elements
+      ymm9  = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3]
+      ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3]
+
+      ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3]
+      ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3]
+
+      ymm0 = _mm256_broadcast_sd((double const *)&ones);
+
+      // extract a00
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+      // perform mul operation
+      ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1);
+
+      // extract a11
+      ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+      ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1)));
+      ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2)));
+      ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3)));
+      ymm5 = _mm256_broadcast_sd((double const *)(a11
+ (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm8, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm8, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm8, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm8, ymm15); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2))); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm9, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm9, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm9, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm9, ymm15); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3))); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm12 = _mm256_fnmadd_pd(ymm5, ymm10, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm10, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm10, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm10, ymm15); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a44 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 4))); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + //(ROw4): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm5, ymm11, ymm12); + ymm13 = _mm256_fnmadd_pd(ymm6, ymm11, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm11, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm11, ymm15); + + // perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 5))); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 7)); + + a11 += rs_a; + + // extract a55 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + //(ROw5): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm6, ymm12, ymm13); + ymm14 = _mm256_fnmadd_pd(ymm7, ymm12, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm12, ymm15); + + // perform mul operation + 
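// note: DTRSM_SMALL_DIV_OR_SCALE divides by the packed diagonal element when BLIS_DISABLE_TRSM_PREINVERSION is defined, and multiplies by the packed reciprocal when BLIS_ENABLE_TRSM_PREINVERSION is defined (d11_pack then holds 1.0/diag); as a scalar sketch, the next statement computes B[5][j] = B[5][j] / A[5][5] for each column j of this transposed block +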
ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 6))); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 7))); + + a11 += rs_a; + + // extract a66 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + //(ROw6): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm7, ymm13, ymm14); + ymm15 = _mm256_fnmadd_pd(ymm16, ymm13, ymm15); + + // perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + // extract a77 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 7)); + + a11 += rs_a; + //(ROw7): FMA operations + ymm15 = _mm256_fnmadd_pd(ymm16, ymm14, ymm15); + + // perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); // B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if (3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3] + } + else if (2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + } + else if (1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + } + } + } + + //======================M remainder cases================================ + dim_t m_rem = m - i; + + if (m_rem >= 4) // implementation for remainder rows (when 'M' is not a
multiple of d_mr) + { + a10 = L + (i * cs_a); // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3 + 3)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + rs_a * 1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + rs_a * 2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + rs_a * 3 + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM operation to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter); + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8(b11, cs_b, AlphaVal); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 * cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 2 * 
cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm3, ymm9, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm3, ymm13, ymm15); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm9, ymm12); + ymm16 = _mm256_fnmadd_pd(ymm4, ymm13, ymm16); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + a11 += rs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm3, ymm10, ymm12); + ymm16 = _mm256_fnmadd_pd(ymm3, ymm14, ymm16); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm0); + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a)); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm16 = _mm256_fnmadd_pd(ymm2, ymm15, ymm16); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + + a11 += rs_a; + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8_AND_STORE(b11, cs_b) + + _mm256_storeu_pd((double *)(b11 + 0 * cs_b), ymm0); + _mm256_storeu_pd((double *)(b11 + 1 * cs_b), ymm1); + _mm256_storeu_pd((double *)(b11 + 2 * cs_b), ymm2); + _mm256_storeu_pd((double *)(b11 + 3 * cs_b), ymm3); + _mm256_storeu_pd((double *)(b11 + 4 * cs_b), ymm4); + _mm256_storeu_pd((double *)(b11 + 5 * cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + 6 * cs_b), ymm6); + _mm256_storeu_pd((double *)(b11 + 7 * cs_b), ymm7); + } + + dim_t n_rem = n - j; + if (n_rem >= 4) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + j * cs_b; // pointer to block of B to be used for GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] 
B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm2,ymm12,ymm13); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm3,ymm12,ymm14); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm13, ymm14); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm13, ymm15); + + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm14, ymm15); + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3] + + n_rem -= 4; + j += 4; + } + if (n_rem) + { + a10 = D_A_pack; + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + j * cs_b; // pointer to block of B to be used for GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks 
of 4x4) + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm0 = _mm256_broadcast_sd((double const *)&ones); + + ////extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 1)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 
cs_a * 3)); + + a11 += rs_a; + + //(ROw1): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm2, ymm8, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm10 = _mm256_fnmadd_pd(ymm3, ymm8, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm8, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 2)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw2): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm3, ymm9, ymm10); + ymm11 = _mm256_fnmadd_pd(ymm4, ymm9, ymm11); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + cs_a * 3)); + + a11 += rs_a; + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(ROw5): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm4, ymm10, ymm11); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if (3 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + } + else if (2 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + } + else if (1 == n_rem) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + } + } + m_rem -= 4; + i += 4; + } + + if (m_rem) + { + a10 = L + (i * cs_a); // pointer to block of A to be used for GEMM + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if (3 == m_rem) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = 
_mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + + // cols + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + + ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal)); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm4 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4)); + ymm5 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 4 + 2)); + ymm5 = _mm256_insertf128_pd(ymm5, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5)); + ymm6 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 5 + 2)); + ymm6 = _mm256_insertf128_pd(ymm6, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6)); + ymm7 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 6 + 2)); + ymm7 = _mm256_insertf128_pd(ymm7, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7)); + ymm8 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 7 + 2)); + ymm8 = _mm256_insertf128_pd(ymm8, xmm5, 0); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf128_pd(ymm9, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm10, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm11, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm12, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm13, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm14, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 6), 
_mm256_extractf128_pd(ymm15, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf128_pd(ymm16, 0)); + + _mm_storel_pd((double *)(b11 + cs_b * 0 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm11, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm12, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf128_pd(ymm13, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf128_pd(ymm14, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 6 + 2), _mm256_extractf128_pd(ymm15, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 7 + 2), _mm256_extractf128_pd(ymm16, 1)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 8, rs_a, cs_b, is_unitdiag); + } + + dim_t n_rem = n - j; + if ((n_rem >= 4)) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11, 1)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j += 4; + } + + if (n_rem) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j 
* cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if (2 == m_rem) // Repetitive A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + // cols + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal)); + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0); + + xmm5 =
_mm_loadu_pd((double const *)(b11 + cs_b * 4)); + ymm5 = _mm256_insertf128_pd(ymm5, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5)); + ymm6 = _mm256_insertf128_pd(ymm6, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6)); + ymm7 = _mm256_insertf128_pd(ymm7, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7)); + ymm8 = _mm256_insertf128_pd(ymm8, xmm5, 0); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf128_pd(ymm9, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm10, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm11, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm12, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm13, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm14, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf128_pd(ymm15, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf128_pd(ymm16, 0)); + + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 8, rs_a, cs_b, is_unitdiag); + } + + dim_t n_rem = n - j; + if ((n_rem >= 4)) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT +; + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j += 4; + } + if (n_rem) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to 
block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -= 2; + i += 2; + } + else if (1 == m_rem) // Repetitive A blocks will be 1*1 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < i; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < i; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + rs_a * x)); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * x), ymm0); + } + } + // cols + for (j = 0; (j + d_nr - 1) < n; j += d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter); + + /// GEMM code ends/// + ymm0 = _mm256_broadcast_sd((double const*)(&AlphaVal)); + ymm1 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 0)); + ymm2 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 1)); + ymm3 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 2)); + ymm4 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 3)); + ymm5 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 4)); + ymm6 =
_mm256_broadcast_sd((double const*)(b11 + cs_b * 5)); + ymm7 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 6)); + ymm8 = _mm256_broadcast_sd((double const*)(b11 + cs_b * 7)); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storel_pd((double *)(b11 + cs_b * 0), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm10, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm11, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm12, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf128_pd(ymm13, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf128_pd(ymm14, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 6), _mm256_extractf128_pd(ymm15, 0)); + _mm_storel_pd((double *)(b11 + cs_b * 7), _mm256_extractf128_pd(ymm16, 0)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 8, rs_a, cs_b, is_unitdiag); + } + dim_t n_rem = n - j; + if ((n_rem >= 4)) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3)); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storel_pd((b11 + cs_b * 0), _mm256_extractf128_pd(ymm8, 0)); + _mm_storel_pd((b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0)); + _mm_storel_pd((b11 + cs_b * 2), _mm256_extractf128_pd(ymm10, 0)); + _mm_storel_pd((b11 + cs_b * 3), _mm256_extractf128_pd(ymm11, 0)); + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AlXB_ref(a11, b11, m_rem, 4, rs_a, cs_b, is_unitdiag); + n_rem -= 4; + j += 4; + } + + if (n_rem) + { + a10 = D_A_pack; // pointer to block of A to be used for GEMM + a11 = L + (i * rs_a) + (i * cs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b); // pointer to block of B to be used for GEMM + b11 = B + i + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = i; // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + if (3 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + 
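// BLIS_PRE_DTRSM_SMALL_1M_3N is expected to apply the same pre-TRSM update done inline for the wider m-remainder cases above: scale the single remaining row of B11 by AlphaVal across three columns and subtract the GEMM accumulators +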
BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 2, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 2, rs_a, cs_b, is_unitdiag); + } + else if (1 == n_rem) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AutXB_ref(a11, b11, m_rem, 1, cs_a, cs_b, is_unitdiag); + else dtrsm_AlXB_ref(a11, b11, m_rem, 1, rs_a, cs_b, is_unitdiag); + } + } + m_rem -= 1; + i += 1; + } + } + + if ((required_packing_A == 1) && + bli_mem_is_alloc(&local_mem_buf_A_s)) + { + bli_membrk_release(&rntm, &local_mem_buf_A_s); + } + return BLIS_SUCCESS; } + // LLTN LUNN BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512 (