From b84157fed646cd7a252d8e56582d02a61ff65f1f Mon Sep 17 00:00:00 2001 From: Shubham Date: Tue, 17 Jan 2023 23:04:27 +0530 Subject: [PATCH] Added AVX512 DTRSM small RLNN/RUTN variant kernels - 8x8 kernels are used for DTRSM SMALL - Implemented fringe cases with below block sizes 8x8, 8x4, 8x3, 8x2, 8x1 4x8, 4x4, 4x3, 4x2, 4x1 3x8, 3x4, 3x3, 3x2, 3x1 2x8, 2x4, 2x3, 2x2, 2x1 1x8, 1x4, 1x3, 1x2, 1x1 AMD-Internal: [CPUPL-2745] Change-Id: Ifb8cfba6958e1c89ddbfa18893127ab6d44cc367 --- frame/compat/bla_trsm_amd.c | 10 +- kernels/zen4/3/bli_trsm_small_AVX512.c | 2416 ++++++++++++++++++++++-- 2 files changed, 2309 insertions(+), 117 deletions(-) diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index b5f295854..680e3c1bf 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -956,9 +956,6 @@ void dtrsm_blis_impl // Query the architecture ID arch_t id = bli_arch_query_id(); -#if defined(BLIS_KERNELS_ZEN4) - bool uplo, transa; -#endif switch(id) { case BLIS_ARCH_ZEN4: @@ -969,11 +966,8 @@ void dtrsm_blis_impl // for n < 200 avx2 kernels are performing better, but if // n is a multiple of 8 then there will be no fringe case for avx512, // in such cases avx512 kernels will perform better. - uplo = bli_obj_is_upper(&ao); - transa = bli_obj_has_trans(&ao); - if(( ((blis_side == BLIS_RIGHT) && (uplo == true) && (transa == false)) || - ((blis_side == BLIS_RIGHT) && (uplo == false) && (transa == true))) && - ((n0 > 400) && (m0 > 50))) + if( (blis_side == BLIS_RIGHT) && + ((n0 > 300) && (m0 > 50))) { status = bli_trsm_small_AVX512( blis_side, diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c index 9912c8e4a..0a70ef5f0 100644 --- a/kernels/zen4/3/bli_trsm_small_AVX512.c +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -120,6 +120,17 @@ zmm30 = _mm512_setzero_pd(); \ zmm31 = _mm512_setzero_pd(); +#define BLIS_SET_YMM_REG_ZEROS_FOR_N_REM \ + ymm3 = _mm256_setzero_pd(); \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + /* declaration of trsm small kernels function pointer */ @@ -1067,7 +1078,7 @@ err_t bli_trsm_small_mt_AVX512 } \ for (; itr2 > 0; itr2--) \ { \ - ymm23 = _mm256_broadcast_sd(b10); \ + ymm23 = _mm256_broadcast_sd(b10_2); \ \ ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 0))); \ ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 1))); \ @@ -2731,8 +2742,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); @@ -2817,15 +2827,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx8m(a01, b10, cs_b, p_lda, k_iter) @@ -2916,15 +2918,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; 
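/* Context for the fringe kernels below: every BLIS_DTRSM_SMALL_GEMM_4nx*m
 * macro accumulates the A01*B10 product into ymm registers with fmadd, so
 * those registers must be zeroed first (the job of the new
 * BLIS_SET_YMM_REG_ZEROS_FOR_N_REM macro). A minimal, self-contained sketch
 * of that accumulate pattern for a 4-column, 4-row strip follows; it is
 * illustrative only, not the macro body, and gemm_4nx4m_sketch is a
 * hypothetical name. dim_t is the BLIS signed index type.
 */
#include <immintrin.h>

static void gemm_4nx4m_sketch(const double *a01, const double *b10,
                              dim_t cs_b, dim_t p_lda, dim_t k_iter,
                              __m256d acc[4])
{
    /* Fill zeros into ymm registers used in gemm accumulation. */
    for (int n = 0; n < 4; ++n) acc[n] = _mm256_setzero_pd();

    for (dim_t k = 0; k < k_iter; ++k)
    {
        __m256d b = _mm256_loadu_pd(b10);  /* 4 rows of one B10 column */
        for (int n = 0; n < 4; ++n)
        {
            /* packed A01 layout: element (k, n) lives at a01[k + n*p_lda] */
            __m256d a = _mm256_broadcast_sd(a01 + n * p_lda);
            /* acc += a*b; the caller later forms alpha*B11 - acc via fmsub */
            acc[n] = _mm256_fmadd_pd(a, b, acc[n]);
        }
        a01 += 1;     /* next k element in the packed buffer */
        b10 += cs_b;  /* next column of B10 */
    }
}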
// number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx4m(a01, b10, cs_b, p_lda, k_iter) @@ -3016,15 +3010,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx3m(a01, b10, cs_b, p_lda, k_iter) @@ -3121,15 +3107,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx2m(a01, b10, cs_b, p_lda, k_iter) @@ -3220,15 +3198,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx1m(a01, b10, cs_b, p_lda, k_iter) @@ -3325,8 +3295,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + p_lda * 1)); @@ -3405,15 +3374,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx8m(a01, b10, cs_b, p_lda, k_iter) @@ -3508,15 +3469,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = 
_mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx4m(a01, b10, cs_b, p_lda, k_iter) @@ -3584,15 +3537,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx3m(a01, b10, cs_b, p_lda, k_iter) @@ -3643,15 +3588,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx2m(a01, b10, cs_b, p_lda, k_iter) @@ -3702,15 +3639,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx1m(a01, b10, cs_b, p_lda, k_iter) @@ -3766,8 +3695,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); @@ -3840,15 +3768,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_2nx8m(a01, b10, cs_b, p_lda, k_iter) @@ -4077,8 +3997,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); ptr_a10_dup += 1; @@ -4301,7 +4220,2286 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512 cntl_t* cntl ) { - return BLIS_NOT_YET_IMPLEMENTED; + dim_t m = bli_obj_length(b); // number of rows + dim_t n = bli_obj_width(b); // number of columns + dim_t d_mr = 8, d_nr = 8; + + bool transa 
= bli_obj_has_trans(a); + dim_t cs_a, rs_a; + double ones = 1.0; + + // Swap rs_a & cs_a in case of non-transpose. + if (transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; + dim_t k_iter; + + bool is_unitdiag = bli_obj_has_unit_diag(a); + + double AlphaVal = *(double *)AlphaObj->buffer; + double *restrict L = bli_obj_buffer_at_off(a); // pointer to matrix A + double *B = bli_obj_buffer_at_off(b); // pointer to matrix B + + double *a01, *a11, *b10, *b11; // pointers for GEMM and TRSM blocks + + bool required_packing_A = true; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; // pointer to A01 pack buffer + double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack + rntm_t rntm; + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ((d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); // acquire memory for A01 pack + if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s)) + return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if (NULL == D_A_pack) + return BLIS_NULL_POINTER; + } + __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11; + __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15, ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm22, ymm23, ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + __m128d xmm5, xmm0; + + /* + Performs solving TRSM for 8 rows at a time from 0 to n/8 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 8x8 to 8x(n-8) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ + for (j = (n - d_nr); j > -1; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j * rs_a) + (j + d_nr) * cs_a; //pointer to block of A to be used in GEMM + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + + dim_t p_lda = (n - j - d_nr); //packed leading dimension + + // perform copy of A to packed buffer D_A_pack + if (transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 8x(n-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack_avx512 + ( + 'R', + p_lda, + 1, + a01, + cs_a, + D_A_pack, + p_lda, + d_nr + ); + /* + Pack 8 diagonal elements of A block into an array + a. 
This helps to utilize cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element_avx512 + ( + is_unitdiag, + a11, + cs_a, + d11_pack, + d_nr + ); + } + else + { + bli_dtrsm_small_pack_avx512 + ( + 'R', + p_lda, + 0, + a01, + rs_a, + D_A_pack, + p_lda, + d_nr + ); + dtrsm_small_pack_diag_element_avx512 + ( + is_unitdiag, + a11, + rs_a, + d11_pack, + d_nr + ); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (j + d_nr) * cs_b; //pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_ZMM_REG_ZEROS + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11); + /* + Load b11 of size 8x8 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x8(AlphaVal, b11, cs_b) + + + + /* + Compute 8x8 TRSM block by using GEMM block output in register + a. The 8x8 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 zmm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + * zmm11 -= ((zmm12 * A11[2][3]) + (zmm13 * A11[2][4]) + + * (zmm14 * A11[2][5]) + (zmm15 * A11[2][6]) + + * (zmm16 * A11[2][7])) / A11[2][2] + */ + + // extract a77 + zmm0 = _mm512_set1_pd(*(d11_pack + 7)); + zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm0); + _mm512_storeu_pd((double *)(b11 + 7 * cs_b), zmm16); + + // extract a66 + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (6 * rs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (5 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm0, zmm16, zmm15); + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (4 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm1, zmm16, zmm14); + zmm1 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (3 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm16, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm16, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm16, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm16, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 6)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm16, zmm9); + zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm1); + _mm512_storeu_pd((double *)(b11 + (6 * cs_b)), zmm15); + + // extract a55 + zmm1 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (5 * rs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (4 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm1, zmm15, zmm14); + zmm1 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (3 * rs_a))); + 
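/* Scalar reference for the 8x8 back-substitution implemented here in zmm
 * registers: columns of B11 are finalized from 7 down to 0, and each
 * finalized column is immediately eliminated from all earlier columns.
 * Hedged sketch only (trsm_8x8_ref is a hypothetical name); it assumes
 * d11_pack was filled as in dtrsm_small_pack_diag_element_avx512.
 * Requires no intrinsics; dim_t is the BLIS signed index type.
 */
static void trsm_8x8_ref(const double *a11, dim_t rs_a, dim_t cs_a,
                         double *b11, dim_t cs_b, const double *d11_pack)
{
    for (dim_t N = 7; N >= 0; --N)
    {
        /* divide column N by A11[N][N]; with BLIS_ENABLE_TRSM_PREINVERSION
           d11_pack holds 1/A11[N][N] and this step becomes a multiply */
        for (dim_t r = 0; r < 8; ++r)
            b11[r + N * cs_b] /= d11_pack[N];

        /* eliminate the solved column N from every earlier column i */
        for (dim_t i = 0; i < N; ++i)
        {
            double a_iN = a11[N * cs_a + i * rs_a];
            for (dim_t r = 0; r < 8; ++r)
                b11[r + i * cs_b] -= a_iN * b11[r + N * cs_b];
        }
    }
}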
zmm13 = _mm512_fnmadd_pd(zmm0, zmm15, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm15, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm15, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm15, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 5)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm15, zmm9); + zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1); + _mm512_storeu_pd((double *)(b11 + (5 * cs_b)), zmm14); + + // extract a44 + zmm0 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (4 * rs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (3 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm14, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm14, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm14, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm14, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 4)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm14, zmm9); + zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm1); + _mm512_storeu_pd((double *)(b11 + (4 * cs_b)), zmm13); + + // extract a33 + zmm1 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (3 * rs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm13, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm13, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm13, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 3)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm13, zmm9); + zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1); + _mm512_storeu_pd((double *)(b11 + (3 * cs_b)), zmm12); + + // extract a22 + zmm0 = _mm512_set1_pd(*(a11 + (3 * cs_a) + (2 * rs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (3 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm12, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (3 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm12, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 2)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm12, zmm9); + zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm1); + _mm512_storeu_pd((double *)(b11 + (2 * cs_b)), zmm11); + + // extract a11 + zmm1 = _mm512_set1_pd(*(a11 + (2 * cs_a) + (1 * rs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (2 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm11, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 1)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm11, zmm9); + zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1); + _mm512_storeu_pd((double *)(b11 + (1 * cs_b)), zmm10); + + // extract a00 + zmm1 = _mm512_set1_pd(*(a11 + (1 * cs_a) + (0 * rs_a))); + zmm0 = _mm512_set1_pd(*(d11_pack + 0)); + zmm9 = _mm512_fnmadd_pd(zmm1, zmm10, zmm9); + zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0); + _mm512_storeu_pd((double *)(b11 + (0 * cs_b)), zmm9); + } + dim_t m_remainder = i + d_mr; + if(m_remainder) + { + if (m_remainder >= 4) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (j + d_nr) * cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + /* + Perform GEMM between a01 and b10 blocks + For first iteration 
there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx4m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x4 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x4(AlphaVal, b11, cs_b) + + /* + Compute 8x4 TRSM block by using GEMM block output in register + a. The 8x4 input (gemm outputs) are stored in combinations of ymm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm256_storeu_pd((double *)(b11 + (7 * cs_b)), ymm16); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm256_storeu_pd((double *)(b11 + (6 * cs_b)), ymm15); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm256_storeu_pd((double *)(b11 + (5 * cs_b)), ymm14); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm256_storeu_pd((double *)(b11 + (4 * 
cs_b)), ymm13); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm256_storeu_pd((double *)(b11 + (3 * cs_b)), ymm12); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm256_storeu_pd((double *)(b11 + (2 * cs_b)), ymm11); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm256_storeu_pd((double *)(b11 + (1 * cs_b)), ymm10); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm256_storeu_pd((double *)(b11 + (0 * cs_b)), ymm9); + m_remainder -= 4; + } + if (m_remainder == 3) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (j + d_nr) * cs_b; // pointer to block of B to be used in GEMM + b11 = B + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx3m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x3 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x3(AlphaVal, b11, cs_b) + /* + Compute 8x3 TRSM block by using GEMM block output in register + a. The 8x3 input (gemm outputs) are stored in combinations of ymm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. 
Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm_storeu_pd((double *)(b11 + (7 * cs_b) + 0), _mm256_castpd256_pd128(ymm16)); + _mm_storel_pd((double *)(b11 + (7 * cs_b) + 2), _mm256_extractf64x2_pd(ymm16, 1)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm_storeu_pd((double *)(b11 + (6 * cs_b) + 0), _mm256_castpd256_pd128(ymm15)); + _mm_storel_pd((double *)(b11 + (6 * cs_b) + 2), _mm256_extractf64x2_pd(ymm15, 1)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storeu_pd((double *)(b11 + (5 * cs_b) + 0), _mm256_castpd256_pd128(ymm14)); + _mm_storel_pd((double *)(b11 + (5 * cs_b) + 2), _mm256_extractf64x2_pd(ymm14, 1)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm_storeu_pd((double *)(b11 + (4 * cs_b) + 0), _mm256_castpd256_pd128(ymm13)); + _mm_storel_pd((double *)(b11 + (4 * cs_b) + 2), _mm256_extractf64x2_pd(ymm13, 1)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + 
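/* The m==3 fringe stores a 4-lane ymm result as exactly three doubles:
 * lanes 0-1 through the low 128 bits, lane 2 through the low half of the
 * upper 128 bits, so lane 3 never touches memory. Minimal sketch of the
 * idiom used throughout this case (store3_sketch is a hypothetical name;
 * requires <immintrin.h> and AVX512DQ/VL for _mm256_extractf64x2_pd).
 */
static inline void store3_sketch(double *p, __m256d v)
{
    _mm_storeu_pd(p,     _mm256_castpd256_pd128(v));     /* p[0], p[1] */
    _mm_storel_pd(p + 2, _mm256_extractf64x2_pd(v, 1));  /* p[2] only  */
}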
ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm_storeu_pd((double *)(b11 + (3 * cs_b) + 0), _mm256_castpd256_pd128(ymm12)); + _mm_storel_pd((double *)(b11 + (3 * cs_b) + 2), _mm256_extractf64x2_pd(ymm12, 1)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm_storeu_pd((double *)(b11 + (2 * cs_b) + 0), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + (2 * cs_b) + 2), _mm256_extractf64x2_pd(ymm11, 1)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm_storeu_pd((double *)(b11 + (1 * cs_b) + 0), _mm256_castpd256_pd128(ymm10)); + _mm_storel_pd((double *)(b11 + (1 * cs_b) + 2), _mm256_extractf64x2_pd(ymm10, 1)); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storeu_pd((double *)(b11 + (0 * cs_b) + 0), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + (0 * cs_b) + 2), _mm256_extractf64x2_pd(ymm9, 1)); + m_remainder -= 3; + } + else if (m_remainder == 2) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (j + d_nr) * cs_b; // pointer to block of B to be used in GEMM + b11 = B + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + BLIS_DTRSM_SMALL_GEMM_8nx2m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + BLIS_PRE_DTRSM_SMALL_8x2(AlphaVal, b11, cs_b) + /* + Compute 8x2 TRSM block by using GEMM block output in register + a. The 8x2 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. 
Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm_storeu_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm_storeu_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storeu_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm_storeu_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = 
_mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm_storeu_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm_storeu_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm_storeu_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storeu_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + m_remainder -= 2; + } + else if (m_remainder == 1) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (j + d_nr) * cs_b; //pointer to block of B to be used in GEMM + b11 = B + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + BLIS_DTRSM_SMALL_GEMM_8nx1m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11); + BLIS_PRE_DTRSM_SMALL_8x1(AlphaVal, b11, cs_b) + /* + Compute 8x1 TRSM block by using GEMM block output in register + a. The 8x1 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. 
Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm_storel_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm_storel_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storel_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm_storel_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = 
_mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm_storel_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm_storel_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm_storel_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storel_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + m_remainder -= 1; + } + } + } + + dim_t n_remainder = j + d_nr; + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (8x8) + above holds for reminder cases too. + */ + + if (n_remainder >= 4) + { + a01 = L + (n_remainder - 4) * rs_a + n_remainder * cs_a; // pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n - n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + for (dim_t x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); + bli_dcopys(*(a01 + rs_a * 2), *(ptr_a10_dup + (p_lda * 2))); + bli_dcopys(*(a01 + rs_a * 3), *(ptr_a10_dup + (p_lda * 3))); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = (n - n_remainder) / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 2) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 3) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 3) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + xmm0 = 
_mm_loadu_pd((double const *)(a01 + (rs_a * 2) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 3) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 3) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + // read diagonal from a11 if not unit diagonal + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2) + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3) + 3)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 2) + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 3) + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = 
DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b + 4)), ymm6); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2) + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3)), ymm9); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3) + 4), ymm10); + } + + dim_t m_remainder = i + d_mr; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3))); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 
cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3)), ymm9); + + m_remainder -= 4; + } + + if (m_remainder) + { + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 3))); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 3) + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3, 1)); + 
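/* Load-then-update idiom used just above for a 3-row column strip: two
 * doubles come in through an xmm load, the third through a broadcast, and
 * insertf128 merges them so the usual alpha*B11 - acc fmsub can run on a
 * full ymm register (lane 3 is ignored). Hedged sketch only;
 * load3_fmsub_sketch is a hypothetical helper. Requires <immintrin.h>.
 */
static inline __m256d load3_fmsub_sketch(const double *p,
                                         __m256d alpha_v, __m256d acc)
{
    __m128d lo = _mm_loadu_pd(p);             /* p[0], p[1]            */
    __m256d v  = _mm256_broadcast_sd(p + 2);  /* p[2] in every lane    */
    v = _mm256_insertf128_pd(v, lo, 0);       /* overwrite lanes 0-1   */
    return _mm256_fmsub_pd(v, alpha_v, acc);  /* alpha*B11 - gemm acc  */
}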
+                _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3, 1));
+                _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5, 1));
+                _mm_storel_pd((double *)(b11 + (cs_b * 2) + 2), _mm256_extractf128_pd(ymm7, 1));
+                _mm_storel_pd((double *)(b11 + (cs_b * 3) + 2), _mm256_extractf128_pd(ymm9, 1));
+
+                m_remainder -= 3;
+            }
+            else if (m_remainder == 2)
+            {
+                a01 = D_A_pack;
+                a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM
+                b10 = B + (n_remainder)*cs_b;       // pointer to block of B to be used in GEMM
+                b11 = B + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM
+
+                k_iter = (n - n_remainder); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_4nx2m(a01, b10, cs_b, p_lda, k_iter)
+
+                ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha
+
+                ymm0 = _mm256_loadu_pd((double const *)b11);
+                // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+                ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+                // B11[0-3][0] * alpha -= ymm0
+
+                ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));
+                // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+                ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);
+                // B11[0-3][1] * alpha -= ymm2
+
+                ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2)));
+                // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+                ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);
+                // B11[0-3][2] * alpha -= ymm4
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 3)));
+                ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);
+                ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);
+                // B11[0-3][3] * alpha -= ymm6
+
+                /// implement TRSM ///
+
+                // extract a33
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3));
+                ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0);
+
+                // extract a22
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+                //(Row 3): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a)));
+                ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3);
+
+                ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+                //(row 2): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3));
+                _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5));
+                _mm_storeu_pd((double *)(b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7));
+                _mm_storeu_pd((double *)(b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9));
+
+                m_remainder -= 2;
+            }
+            else if (m_remainder == 1)
+            {
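+                // Single-row fringe: B elements are brought in with
+                // broadcasts so the same four-column solve sequence can be
+                // reused; only lane 0 holds live data and only that lane is
+                // stored back.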
+                a01 = D_A_pack;
+                a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM
+                b10 = B + (n_remainder)*cs_b;       // pointer to block of B to be used in GEMM
+                b11 = B + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM
+
+                k_iter = (n - n_remainder); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_4nx1m(a01, b10, cs_b, p_lda, k_iter)
+
+                ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha
+
+                ymm0 = _mm256_broadcast_sd((double const *)b11);
+                // B11[0][0] (broadcast)
+                ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+                // B11[0][0] * alpha -= ymm0
+
+                ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b));
+                // B11[0][1] (broadcast)
+                ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);
+                // B11[0][1] * alpha -= ymm2
+
+                ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2)));
+                // B11[0][2] (broadcast)
+                ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);
+                // B11[0][2] * alpha -= ymm4
+
+                ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 3)));
+                // B11[0][3] (broadcast)
+                ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);
+                // B11[0][3] * alpha -= ymm6
+
+                /// implement TRSM ///
+
+                // extract a33
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3));
+                ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0);
+
+                // extract a22
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+                //(Row 3): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a)));
+                ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3);
+
+                ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+                //(row 2): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                _mm_storel_pd((b11 + (cs_b * 0)), _mm256_castpd256_pd128(ymm3));
+                _mm_storel_pd((b11 + (cs_b * 1)), _mm256_castpd256_pd128(ymm5));
+                _mm_storel_pd((b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7));
+                _mm_storel_pd((b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9));
+
+                m_remainder -= 1;
+            }
+        }
+        n_remainder -= 4;
+    }
+
+    if (n_remainder == 3)
+    {
+        a01 = L + 3 * cs_a; // pointer to block of A to be used in GEMM
+        a11 = L;            // pointer to block of A to be used for TRSM
+
+        double *ptr_a10_dup = D_A_pack;
+
+        dim_t p_lda = (n - 3); // packed leading dimension
+        // perform copy of A to packed buffer D_A_pack
+
+        if (transa)
+        {
+            for (dim_t x = 0; x < p_lda; x += 1)
+            {
+                bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0));
+                bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + p_lda * 1));
+                bli_dcopys(*(a01 + rs_a * 2), *(ptr_a10_dup + p_lda * 2));
+                ptr_a10_dup += 1;
+                a01 += cs_a;
+            }
+        }
+        else
+        {
+            dim_t loop_count = (n - 3) / 4;
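+
+            // Non-transpose path: copy the three rows of A01 into the
+            // packed buffer four doubles at a time; any leftover columns
+            // are copied below with 128-bit loads/stores.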
+            for (dim_t x = 0; x < loop_count; x++)
+            {
+                ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4)));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15);
+                ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4)));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15);
+                ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 2) + (x * 4)));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (x * 4)), ymm15);
+            }
+
+            dim_t remainder_loop_count = p_lda - loop_count * 4;
+
+            __m128d xmm0;
+            if (remainder_loop_count != 0)
+            {
+                xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4)));
+                _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0);
+                xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4)));
+                _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0);
+                xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 2) + (loop_count * 4)));
+                _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (loop_count * 4)), xmm0);
+            }
+        }
+
+        ymm4 = _mm256_broadcast_sd((double const *)&ones);
+        if (!is_unitdiag)
+        {
+            if (transa)
+            {
+                // broadcast diagonal elements of A11
+                ymm0 = _mm256_broadcast_sd((double const *)(a11));
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1));
+                ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2) + 2));
+            }
+            else
+            {
+                // broadcast diagonal elements of A11
+                ymm0 = _mm256_broadcast_sd((double const *)(a11));
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1));
+                ymm2 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 2) + 2));
+            }
+            ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+            ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+            ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+            ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+#ifdef BLIS_DISABLE_TRSM_PREINVERSION
+            ymm4 = ymm1;
+#endif
+#ifdef BLIS_ENABLE_TRSM_PREINVERSION
+            ymm4 = _mm256_div_pd(ymm4, ymm1);
+#endif
+        }
+        _mm256_storeu_pd((double *)(d11_pack), ymm4);
+
+        for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction
+        {
+            a01 = D_A_pack;
+            a11 = L;                // pointer to block of A to be used for TRSM
+            b10 = B + i + 3 * cs_b; // pointer to block of B to be used in GEMM
+            b11 = B + i;            // pointer to block of B to be used for TRSM
+
+            k_iter = (n - 3); // number of GEMM operations to be done (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+            /// GEMM implementation starts ///
+            BLIS_DTRSM_SMALL_GEMM_3nx8m(a01, b10, cs_b, p_lda, k_iter)
+
+            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
+
+            ymm0 = _mm256_loadu_pd((double const *)b11);
+            // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+            ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));
+            // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+
+            ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+            // B11[0-3][0] * alpha -= ymm0
+            ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);
+            // B11[4-7][0] * alpha -= ymm1
+
+            ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));
+            // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+            ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4));
+            // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+
+            ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);
+            // B11[0-3][1] * alpha -= ymm2
+            ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6);
+            // B11[4-7][1] * alpha -= ymm3
+
+            ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2)));
+            // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+            ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2) + 4));
+            // B11[4][2] B11[5][2] B11[6][2] B11[7][2]
+
+            ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);
+            // B11[0-3][2] * alpha -= ymm4
+            ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8);
+            // B11[4-7][2] * alpha -= ymm5
+
+            /// implement TRSM ///
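+
+            // With three columns, column 2 is solved first; columns 1 and 0
+            // are then updated with its result and solved in turn.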
+
+            // extract a22
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+            ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+            ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0);
+
+            // extract a11
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+            //(row 2): FMA operations
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+
+            ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+            ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6);
+
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+
+            ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+            ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4);
+
+            ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+            ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0);
+
+            // extract a00
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+            //(Row 1): FMA operations
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+
+            ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+            ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4);
+
+            ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+            ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0);
+
+            _mm256_storeu_pd((double *)b11, ymm3);
+            _mm256_storeu_pd((double *)(b11 + 4), ymm4);
+            _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);
+            _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6);
+            _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7);
+            _mm256_storeu_pd((double *)(b11 + (cs_b * 2) + 4), ymm8);
+        }
+
+        dim_t m_remainder = i + d_mr;
+        if (m_remainder >= 4)
+        {
+            a01 = D_A_pack;
+            a11 = L;                                // pointer to block of A to be used for TRSM
+            b10 = B + (m_remainder - 4) + 3 * cs_b; // pointer to block of B to be used in GEMM
+            b11 = B + (m_remainder - 4);            // pointer to block of B to be used for TRSM
+
+            k_iter = (n - 3); // number of GEMM operations to be done (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+            /// GEMM implementation starts ///
+            BLIS_DTRSM_SMALL_GEMM_3nx4m(a01, b10, cs_b, p_lda, k_iter)
+
+            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha
+
+            ymm0 = _mm256_loadu_pd((double const *)b11);
+            // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+            ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+            // B11[0-3][0] * alpha -= ymm0
+
+            ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));
+            // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+            ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);
+            // B11[0-3][1] * alpha -= ymm2
+
+            ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2)));
+            // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+            ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);
+            // B11[0-3][2] * alpha -= ymm4
+
+            /// implement TRSM ///
+            // extract a22
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+            ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+
+            // extract a11
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+            //(row 2): FMA operations
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+            ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+            ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+
+            ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+            // extract a00
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+            //(Row 1): FMA operations
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+            ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+            ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
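+
+            // Write back the three solved 4-element columns.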
+            _mm256_storeu_pd((double *)b11, ymm3);
+            _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);
+            _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7);
+
+            m_remainder -= 4;
+        }
+
+        if (m_remainder)
+        {
+            if (m_remainder == 3)
+            {
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 3 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 3); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_3nx3m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+                // extract a22
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+                ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+                //(row 2): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_3N_3M(b11, cs_b)
+
+                m_remainder -= 3;
+            }
+            else if (m_remainder == 2)
+            {
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 3 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 3); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_3nx2m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+
+                // extract a22
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+                ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+                //(row 2): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_3N_2M(b11, cs_b)
+
+                m_remainder -= 2;
+            }
+            else if (m_remainder == 1)
+            {
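+                // The PRE/POST macros handle the partial-width alpha-scaled
+                // load and the writeback for this fringe size, so the solve
+                // itself is identical to the wider cases.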
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 3 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 3); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_3nx1m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+
+                // extract a22
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+                ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0);
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+                //(row 2): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a)));
+                ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5);
+
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a)));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3);
+
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_3N_1M(b11, cs_b)
+
+                m_remainder -= 1;
+            }
+        }
+        n_remainder -= 3;
+    }
+    else if (n_remainder == 2)
+    {
+        a01 = L + 2 * cs_a; // pointer to block of A to be used in GEMM
+        a11 = L;            // pointer to block of A to be used for TRSM
+
+        double *ptr_a10_dup = D_A_pack;
+
+        dim_t p_lda = (n - 2); // packed leading dimension
+        // perform copy of A to packed buffer D_A_pack
+
+        if (transa)
+        {
+            for (dim_t x = 0; x < p_lda; x += 1)
+            {
+                bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0)));
+                bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1)));
+                ptr_a10_dup += 1;
+                a01 += cs_a;
+            }
+        }
+        else
+        {
+            dim_t loop_count = (n - 2) / 4;
+
+            for (dim_t x = 0; x < loop_count; x++)
+            {
+                ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4)));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15);
+                ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4)));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15);
+            }
+
+            dim_t remainder_loop_count = p_lda - loop_count * 4;
+
+            __m128d xmm0;
+            if (remainder_loop_count != 0)
+            {
+                xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4)));
+                _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0);
+                xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4)));
+                _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0);
+            }
+        }
+
+        ymm4 = _mm256_broadcast_sd((double const *)&ones);
+        if (!is_unitdiag)
+        {
+            if (transa)
+            {
+                // broadcast diagonal elements of A11
+                ymm0 = _mm256_broadcast_sd((double const *)(a11));
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1));
+            }
+            else
+            {
+                // broadcast diagonal elements of A11
+                ymm0 = _mm256_broadcast_sd((double const *)(a11));
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1));
+            }
+            ymm2 = _mm256_broadcast_sd((double const *)&ones);
+            ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+            ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+            ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+            ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+#ifdef BLIS_DISABLE_TRSM_PREINVERSION
+            ymm4 = ymm1;
+#endif
+#ifdef BLIS_ENABLE_TRSM_PREINVERSION
+            ymm4 = _mm256_div_pd(ymm4, ymm1);
+#endif
+        }
+        _mm256_storeu_pd((double *)(d11_pack), ymm4);
+
+        for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction
+        {
+            a01 = D_A_pack;
+            a11 = L;                // pointer to block of A to be used for TRSM
+            b10 = B + i + 2 * cs_b; // pointer to block of B to be used in GEMM
+            b11 = B + i;            // pointer to block of B to be used for TRSM
+
+            k_iter = (n - 2); // number of GEMM operations to be done (in blocks of 4x4)
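+            // GEMM stage: subtract the contribution of the (n - 2) columns
+            // of B that were solved in earlier iterations before solving
+            // the remaining two columns below.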
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_N_REM
+
+            /// GEMM implementation starts ///
+            BLIS_DTRSM_SMALL_GEMM_2nx8m(a01, b10, cs_b, p_lda, k_iter)
+
+            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
+
+            ymm0 = _mm256_loadu_pd((double const *)b11);
+            // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+            ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));
+            // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+
+            ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+            // B11[0-3][0] * alpha -= ymm0
+            ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);
+            // B11[4-7][0] * alpha -= ymm1
+
+            ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));
+            // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+            ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4));
+            // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+
+            ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);
+            // B11[0-3][1] * alpha -= ymm2
+            ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6);
+            // B11[4-7][1] * alpha -= ymm3
+
+            /// implement TRSM ///
+
+            // extract a11
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+
+            ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+            ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0);
+
+            // extract a00
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+            //(Row 1): FMA operations
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+
+            ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+            ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4);
+
+            ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+            ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0);
+
+            _mm256_storeu_pd((double *)b11, ymm3);
+            _mm256_storeu_pd((double *)(b11 + 4), ymm4);
+            _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);
+            _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6);
+        }
+
+        dim_t m_remainder = i + d_mr;
+        if (m_remainder >= 4)
+        {
+            a01 = D_A_pack;
+            a11 = L;                                // pointer to block of A to be used for TRSM
+            b10 = B + (m_remainder - 4) + 2 * cs_b; // pointer to block of B to be used in GEMM
+            b11 = B + (m_remainder - 4);            // pointer to block of B to be used for TRSM
+
+            k_iter = (n - 2); // number of GEMM operations to be done (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            ymm3 = _mm256_setzero_pd();
+            ymm5 = _mm256_setzero_pd();
+
+            /// GEMM implementation starts ///
+            BLIS_DTRSM_SMALL_GEMM_2nx4m(a01, b10, cs_b, p_lda, k_iter)
+
+            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
+            // register to hold alpha
+
+            ymm0 = _mm256_loadu_pd((double const *)b11);
+            // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+            ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+            // B11[0-3][0] * alpha -= ymm0
+
+            ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b));
+            // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+            ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);
+            // B11[0-3][1] * alpha -= ymm2
+
+            /// implement TRSM ///
+
+            // extract a11
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+            ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+            // extract a00
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+            //(Row 1): FMA operations
+            ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+            ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+            ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+            _mm256_storeu_pd((double *)b11, ymm3);
+            _mm256_storeu_pd((double *)(b11 + cs_b), ymm5);
+
+            m_remainder -= 4;
+        }
+
+        if (m_remainder)
+        {
+            if (m_remainder == 3)
+            {
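+                // m fringe for the two-column block: the 3-, 2- and 1-row
+                // paths below mirror the 4-row path, differing only in how
+                // B is loaded and stored.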
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 2 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 2); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                ymm3 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_2nx3m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_2N_3M(b11, cs_b)
+
+                m_remainder -= 3;
+            }
+            else if (m_remainder == 2)
+            {
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 2 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 2); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                ymm3 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_2nx2m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal, b11, cs_b)
+                /// implement TRSM ///
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_2N_2M(b11, cs_b)
+
+                m_remainder -= 2;
+            }
+            else if (m_remainder == 1)
+            {
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 2 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 2); // number of GEMM operations to be done (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in gemm accumulations */
+                ymm3 = _mm256_setzero_pd();
+                ymm5 = _mm256_setzero_pd();
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_2nx1m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal, b11, cs_b)
+                /// implement TRSM ///
+
+                // extract a11
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
+                ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0);
+
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+
+                //(Row 1): FMA operations
+                ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a));
+                ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3);
+
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_2N_1M(b11, cs_b)
+
+                m_remainder -= 1;
+            }
+        }
+        n_remainder -= 2;
+    }
+    else if (n_remainder == 1)
+    {
+        a01 = L + 1 * cs_a; // pointer to block of A to be used in GEMM
+        a11 = L;            // pointer to block of A to be used for TRSM
+
+        double *ptr_a10_dup = D_A_pack;
+
+        dim_t p_lda = (n - 1); // packed leading dimension
+        // perform copy of A to packed buffer D_A_pack
+
+        if (transa)
+        {
+            for (dim_t x = 0; x < p_lda; x += 1)
+            {
+                bli_dcopys(*(a01), *(ptr_a10_dup));
+                ptr_a10_dup += 1;
+                a01 += cs_a;
+            }
+        }
+        else
+        {
+            dim_t loop_count = (n - 1) / 4;
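+
+            // Single-column case: only one row of A01 is packed, again four
+            // doubles per iteration with a 128-bit copy for the tail.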
+            for (dim_t x = 0; x < loop_count; x++)
+            {
+                ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4)));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15);
+            }
+
+            dim_t remainder_loop_count = p_lda - loop_count * 4;
+
+            __m128d xmm0;
+            if (remainder_loop_count != 0)
+            {
+                xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4)));
+                _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0);
+            }
+        }
+
+        ymm4 = _mm256_broadcast_sd((double const *)&ones);
+        if (!is_unitdiag)
+        {
+            // broadcast diagonal elements of A11
+            ymm0 = _mm256_broadcast_sd((double const *)(a11));
+            ymm1 = _mm256_broadcast_sd((double const *)&ones);
+            ymm2 = _mm256_broadcast_sd((double const *)&ones);
+            ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+            ymm0 = _mm256_unpacklo_pd(ymm0, ymm1);
+            ymm1 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+            ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C);
+#ifdef BLIS_DISABLE_TRSM_PREINVERSION
+            ymm4 = ymm1;
+#endif
+#ifdef BLIS_ENABLE_TRSM_PREINVERSION
+            ymm4 = _mm256_div_pd(ymm4, ymm1);
+#endif
+        }
+        _mm256_storeu_pd((double *)(d11_pack), ymm4);
+
+        for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction
+        {
+            a01 = D_A_pack;
+            a11 = L;                // pointer to block of A to be used for TRSM
+            b10 = B + i + 1 * cs_b; // pointer to block of B to be used in GEMM
+            b11 = B + i;            // pointer to block of B to be used for TRSM
+
+            k_iter = (n - 1); // number of GEMM operations to be done (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            ymm3 = _mm256_setzero_pd();
+            ymm4 = _mm256_setzero_pd();
+
+            /// GEMM implementation starts ///
+            BLIS_DTRSM_SMALL_GEMM_1nx8m(a01, b10, cs_b, p_lda, k_iter)
+
+            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);
+
+            ymm0 = _mm256_loadu_pd((double const *)b11);
+            // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+            ymm1 = _mm256_loadu_pd((double const *)(b11 + 4));
+            // B11[4][0] B11[5][0] B11[6][0] B11[7][0]
+
+            ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+            // B11[0-3][0] * alpha -= ymm0
+            ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4);
+            // B11[4-7][0] * alpha -= ymm1
+
+            /// implement TRSM ///
+            // extract a00
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+            ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+            ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0);
+
+            _mm256_storeu_pd((double *)b11, ymm3);
+            _mm256_storeu_pd((double *)(b11 + 4), ymm4);
+        }
+
+        dim_t m_remainder = i + d_mr;
+        if (m_remainder >= 4)
+        {
+            a01 = D_A_pack;
+            a11 = L;                                // pointer to block of A to be used for TRSM
+            b10 = B + (m_remainder - 4) + 1 * cs_b; // pointer to block of B to be used in GEMM
+            b11 = B + (m_remainder - 4);            // pointer to block of B to be used for TRSM
+
+            k_iter = (n - 1); // number of GEMM operations to be done (in blocks of 4x4)
+
+            ymm3 = _mm256_setzero_pd();
+
+            /// GEMM implementation starts ///
+            BLIS_DTRSM_SMALL_GEMM_1nx4m(a01, b10, cs_b, p_lda, k_iter)
+
+            ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha
+
+            ymm0 = _mm256_loadu_pd((double const *)b11);
+            // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+            ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);
+            // B11[0-3][0] * alpha -= ymm0
+
+            /// implement TRSM ///
+            // extract a00
+            ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+            ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+            _mm256_storeu_pd((double *)b11, ymm3);
+
+            m_remainder -= 4;
+        }
+
+        if (m_remainder)
+        {
+            if (m_remainder == 3)
+            {
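+                // For n == 1 there is no update from other columns inside
+                // the block: after the GEMM stage the column is simply
+                // scaled by the packed diagonal element.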
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 1 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 1); // number of GEMM operations to be done (in blocks of 4x4)
+
+                ymm3 = _mm256_setzero_pd();
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_1nx3m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                ymm0 = _mm256_loadu_pd((double const *)b11);
+                ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07);
+
+                BLIS_POST_DTRSM_SMALL_1N_3M(b11, cs_b)
+
+                m_remainder -= 3;
+            }
+            else if (m_remainder == 2)
+            {
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 1 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 1); // number of GEMM operations to be done (in blocks of 4x4)
+
+                ymm3 = _mm256_setzero_pd();
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_1nx2m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_1N_2M(b11, cs_b)
+
+                m_remainder -= 2;
+            }
+            else if (m_remainder == 1)
+            {
+                a01 = D_A_pack;
+                a11 = L;            // pointer to block of A to be used for TRSM
+                b10 = B + 1 * cs_b; // pointer to block of B to be used in GEMM
+                b11 = B;            // pointer to block of B to be used for TRSM
+
+                k_iter = (n - 1); // number of GEMM operations to be done (in blocks of 4x4)
+
+                ymm3 = _mm256_setzero_pd();
+
+                /// GEMM implementation starts ///
+                BLIS_DTRSM_SMALL_GEMM_1nx1m(a01, b10, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal, b11, cs_b)
+
+                /// implement TRSM ///
+                // extract a00
+                ymm0 = _mm256_broadcast_sd((double const *)(d11_pack));
+                ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0);
+
+                BLIS_POST_DTRSM_SMALL_1N_1M(b11, cs_b)
+
+            }
+        }
+    }
+
+    if ((required_packing_A) && bli_mem_is_alloc(&local_mem_buf_A_s))
+    {
+        bli_membrk_release(&rntm,
+                           &local_mem_buf_A_s);
+    }
+    return BLIS_SUCCESS;
+}

// LLNN - LUTN