diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index b5f295854..680e3c1bf 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -956,9 +956,6 @@ void dtrsm_blis_impl // Query the architecture ID arch_t id = bli_arch_query_id(); -#if defined(BLIS_KERNELS_ZEN4) - bool uplo, transa; -#endif switch(id) { case BLIS_ARCH_ZEN4: @@ -969,11 +966,8 @@ void dtrsm_blis_impl // for n < 200 avx2 kernels are performing better, but if // n is a multiple of 8 then there will be no fringe case for avx512, // in such cases avx512 kernels will perform better. - uplo = bli_obj_is_upper(&ao); - transa = bli_obj_has_trans(&ao); - if(( ((blis_side == BLIS_RIGHT) && (uplo == true) && (transa == false)) || - ((blis_side == BLIS_RIGHT) && (uplo == false) && (transa == true))) && - ((n0 > 400) && (m0 > 50))) + if( (blis_side == BLIS_RIGHT) && + ((n0 > 300) && (m0 > 50))) { status = bli_trsm_small_AVX512( blis_side, diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c index 9912c8e4a..0a70ef5f0 100644 --- a/kernels/zen4/3/bli_trsm_small_AVX512.c +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -120,6 +120,17 @@ zmm30 = _mm512_setzero_pd(); \ zmm31 = _mm512_setzero_pd(); +#define BLIS_SET_YMM_REG_ZEROS_FOR_N_REM \ + ymm3 = _mm256_setzero_pd(); \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + /* declaration of trsm small kernels function pointer */ @@ -1067,7 +1078,7 @@ err_t bli_trsm_small_mt_AVX512 } \ for (; itr2 > 0; itr2--) \ { \ - ymm23 = _mm256_broadcast_sd(b10); \ + ymm23 = _mm256_broadcast_sd(b10_2); \ \ ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 0))); \ ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 1))); \ @@ -2731,8 +2742,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); @@ -2817,15 +2827,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx8m(a01, b10, cs_b, p_lda, k_iter) @@ -2916,15 +2918,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx4m(a01, b10, cs_b, p_lda, k_iter) @@ -3016,15 +3010,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx3m(a01, b10, cs_b, p_lda, k_iter) @@ -3121,15 +3107,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx2m(a01, b10, cs_b, p_lda, k_iter) @@ -3220,15 +3198,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_4nx1m(a01, b10, cs_b, p_lda, k_iter) @@ -3325,8 +3295,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + p_lda * 1)); @@ -3405,15 +3374,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx8m(a01, b10, cs_b, p_lda, k_iter) @@ -3508,15 +3469,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx4m(a01, b10, cs_b, p_lda, k_iter) @@ -3584,15 +3537,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx3m(a01, b10, cs_b, p_lda, k_iter) @@ -3643,15 +3588,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx2m(a01, b10, cs_b, p_lda, k_iter) @@ -3702,15 +3639,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_3nx1m(a01, b10, cs_b, p_lda, k_iter) @@ -3766,8 +3695,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); @@ -3840,15 +3768,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) /*Fill zeros into ymm registers used in gemm accumulations */ - ymm3 = _mm256_setzero_pd(); - ymm4 = _mm256_setzero_pd(); - ymm5 = _mm256_setzero_pd(); - ymm6 = _mm256_setzero_pd(); - ymm7 = _mm256_setzero_pd(); - ymm8 = _mm256_setzero_pd(); - ymm9 = _mm256_setzero_pd(); - ymm10 = _mm256_setzero_pd(); - ymm15 = _mm256_setzero_pd(); + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM /// GEMM implementation starts/// BLIS_DTRSM_SMALL_GEMM_2nx8m(a01, b10, cs_b, p_lda, k_iter) @@ -4077,8 +3997,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 if (transa) { - dim_t x = 0; - for (x = 0; x < p_lda; x += 1) + for (dim_t x = 0; x < p_lda; x += 1) { bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); ptr_a10_dup += 1; @@ -4301,7 +4220,2286 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512 cntl_t* cntl ) { - return BLIS_NOT_YET_IMPLEMENTED; + dim_t m = bli_obj_length(b); // number of rows + dim_t n = bli_obj_width(b); // number of columns + dim_t d_mr = 8, d_nr = 8; + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + double ones = 1.0; + + // Swap rs_a & cs_a in case of non-transpose. + if (transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; + dim_t k_iter; + + bool is_unitdiag = bli_obj_has_unit_diag(a); + + double AlphaVal = *(double *)AlphaObj->buffer; + double *restrict L = bli_obj_buffer_at_off(a); // pointer to matrix A + double *B = bli_obj_buffer_at_off(b); // pointer to matrix B + + double *a01, *a11, *b10, *b11; // pointers for GEMM and TRSM blocks + + bool required_packing_A = true; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; // pointer to A01 pack buffer + double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack + rntm_t rntm; + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ((d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); // acquire memory for A01 pack + if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s)) + return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if (NULL == D_A_pack) + return BLIS_NULL_POINTER; + } + __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11; + __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15, ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm22, ymm23, ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + __m128d xmm5, xmm0; + + /* + Performs solving TRSM for 8 rows at a time from 0 to n/8 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 8x8 to 8x(n-8) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ + for (j = (n - d_nr); j > -1; j -= d_nr) //loop along 'N' direction + { + a01 = L + (j * rs_a) + (j + d_nr) * cs_a; //pointer to block of A to be used in GEMM + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + + dim_t p_lda = (n - j - d_nr); //packed leading dimension + + // perform copy of A to packed buffer D_A_pack + if (transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 8x(n-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack_avx512 + ( + 'R', + p_lda, + 1, + a01, + cs_a, + D_A_pack, + p_lda, + d_nr + ); + /* + Pack 8 diagonal elements of A block into an array + a. This helps to utilize cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element_avx512 + ( + is_unitdiag, + a11, + cs_a, + d11_pack, + d_nr + ); + } + else + { + bli_dtrsm_small_pack_avx512 + ( + 'R', + p_lda, + 0, + a01, + rs_a, + D_A_pack, + p_lda, + d_nr + ); + dtrsm_small_pack_diag_element_avx512 + ( + is_unitdiag, + a11, + rs_a, + d11_pack, + d_nr + ); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; //pointer to block of A to be used for TRSM + b10 = B + i + (j + d_nr) * cs_b; //pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_ZMM_REG_ZEROS + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11); + /* + Load b11 of size 8x8 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x8(AlphaVal, b11, cs_b) + + + + /* + Compute 8x8 TRSM block by using GEMM block output in register + a. The 8x8 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 zmm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + * zmm11 -= ((zmm12 * A11[2][3]) + (zmm13 * A11[2][4]) + + * (zmm14 * A11[2][5]) + (zmm15 * A11[2][6]) + + * (zmm16 * A11[2][7])) / A11[2][2] + */ + + // extract a77 + zmm0 = _mm512_set1_pd(*(d11_pack + 7)); + zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm0); + _mm512_storeu_pd((double *)(b11 + 7 * cs_b), zmm16); + + // extract a66 + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (6 * rs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (5 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm0, zmm16, zmm15); + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (4 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm1, zmm16, zmm14); + zmm1 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (3 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm16, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm16, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm16, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (7 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm16, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 6)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm16, zmm9); + zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm1); + _mm512_storeu_pd((double *)(b11 + (6 * cs_b)), zmm15); + + // extract a55 + zmm1 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (5 * rs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (4 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm1, zmm15, zmm14); + zmm1 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (3 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm15, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm15, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm15, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (6 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm15, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 5)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm15, zmm9); + zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1); + _mm512_storeu_pd((double *)(b11 + (5 * cs_b)), zmm14); + + // extract a44 + zmm0 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (4 * rs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (3 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm14, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm14, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm14, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (5 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm14, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 4)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm14, zmm9); + zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm1); + _mm512_storeu_pd((double *)(b11 + (4 * cs_b)), zmm13); + + // extract a33 + zmm1 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (3 * rs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (2 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm13, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm13, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (4 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm13, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 3)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm13, zmm9); + zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1); + _mm512_storeu_pd((double *)(b11 + (3 * cs_b)), zmm12); + + // extract a22 + zmm0 = _mm512_set1_pd(*(a11 + (3 * cs_a) + (2 * rs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (3 * cs_a) + (1 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm12, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (3 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm12, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 2)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm12, zmm9); + zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm1); + _mm512_storeu_pd((double *)(b11 + (2 * cs_b)), zmm11); + + // extract a11 + zmm1 = _mm512_set1_pd(*(a11 + (2 * cs_a) + (1 * rs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (2 * cs_a) + (0 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm11, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 1)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm11, zmm9); + zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1); + _mm512_storeu_pd((double *)(b11 + (1 * cs_b)), zmm10); + + // extract a00 + zmm1 = _mm512_set1_pd(*(a11 + (1 * cs_a) + (0 * rs_a))); + zmm0 = _mm512_set1_pd(*(d11_pack + 0)); + zmm9 = _mm512_fnmadd_pd(zmm1, zmm10, zmm9); + zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0); + _mm512_storeu_pd((double *)(b11 + (0 * cs_b)), zmm9); + } + dim_t m_remainder = i + d_mr; + if(m_remainder) + { + if (m_remainder >= 4) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (j + d_nr) * cs_b; //pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx4m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x4 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x4(AlphaVal, b11, cs_b) + + /* + Compute 8x4 TRSM block by using GEMM block output in register + a. The 8x4 input (gemm outputs) are stored in combinations of ymm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm256_storeu_pd((double *)(b11 + (7 * cs_b)), ymm16); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm256_storeu_pd((double *)(b11 + (6 * cs_b)), ymm15); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm256_storeu_pd((double *)(b11 + (5 * cs_b)), ymm14); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm256_storeu_pd((double *)(b11 + (4 * cs_b)), ymm13); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm256_storeu_pd((double *)(b11 + (3 * cs_b)), ymm12); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm256_storeu_pd((double *)(b11 + (2 * cs_b)), ymm11); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm256_storeu_pd((double *)(b11 + (1 * cs_b)), ymm10); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm256_storeu_pd((double *)(b11 + (0 * cs_b)), ymm9); + m_remainder -= 4; + } + if (m_remainder == 3) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (j + d_nr) * cs_b; // pointer to block of B to be used in GEMM + b11 = B + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx3m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x3 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x3(AlphaVal, b11, cs_b) + /* + Compute 8x3 TRSM block by using GEMM block output in register + a. The 8x3 input (gemm outputs) are stored in combinations of ymm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm_storeu_pd((double *)(b11 + (7 * cs_b) + 0), _mm256_castpd256_pd128(ymm16)); + _mm_storel_pd((double *)(b11 + (7 * cs_b) + 2), _mm256_extractf64x2_pd(ymm16, 1)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm_storeu_pd((double *)(b11 + (6 * cs_b) + 0), _mm256_castpd256_pd128(ymm15)); + _mm_storel_pd((double *)(b11 + (6 * cs_b) + 2), _mm256_extractf64x2_pd(ymm15, 1)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storeu_pd((double *)(b11 + (5 * cs_b) + 0), _mm256_castpd256_pd128(ymm14)); + _mm_storel_pd((double *)(b11 + (5 * cs_b) + 2), _mm256_extractf64x2_pd(ymm14, 1)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm_storeu_pd((double *)(b11 + (4 * cs_b) + 0), _mm256_castpd256_pd128(ymm13)); + _mm_storel_pd((double *)(b11 + (4 * cs_b) + 2), _mm256_extractf64x2_pd(ymm13, 1)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm_storeu_pd((double *)(b11 + (3 * cs_b) + 0), _mm256_castpd256_pd128(ymm12)); + _mm_storel_pd((double *)(b11 + (3 * cs_b) + 2), _mm256_extractf64x2_pd(ymm12, 1)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm_storeu_pd((double *)(b11 + (2 * cs_b) + 0), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + (2 * cs_b) + 2), _mm256_extractf64x2_pd(ymm11, 1)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm_storeu_pd((double *)(b11 + (1 * cs_b) + 0), _mm256_castpd256_pd128(ymm10)); + _mm_storel_pd((double *)(b11 + (1 * cs_b) + 2), _mm256_extractf64x2_pd(ymm10, 1)); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storeu_pd((double *)(b11 + (0 * cs_b) + 0), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + (0 * cs_b) + 2), _mm256_extractf64x2_pd(ymm9, 1)); + m_remainder -= 3; + } + else if (m_remainder == 2) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (j + d_nr) * cs_b; // pointer to block of B to be used in GEMM + b11 = B + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + BLIS_DTRSM_SMALL_GEMM_8nx2m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + BLIS_PRE_DTRSM_SMALL_8x2(AlphaVal, b11, cs_b) + /* + Compute 8x2 TRSM block by using GEMM block output in register + a. The 8x2 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm_storeu_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm_storeu_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storeu_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm_storeu_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm_storeu_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm_storeu_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm_storeu_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storeu_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + m_remainder -= 2; + } + else if (m_remainder == 1) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (j * cs_a) + (j * rs_a); //pointer to block of A to be used for TRSM + b10 = B + (j + d_nr) * cs_b; //pointer to block of B to be used in GEMM + b11 = B + (j * cs_b); //pointer to block of B to be used for TRSM + + k_iter = (n - j - d_nr); + BLIS_SET_YMM_REG_ZEROS_AVX512 + BLIS_DTRSM_SMALL_GEMM_8nx1m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11); + BLIS_PRE_DTRSM_SMALL_8x1(AlphaVal, b11, cs_b) + /* + Compute 8x1 TRSM block by using GEMM block output in register + a. The 8x1 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to i=7 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[N][i] } /A11[N][N] + * from i=n+1 + * + * For example 3rd column (B11[2]) -= ((B11[3] * A11[2][3]) + (B11[4] * A11[2][4]) + + * (B11[5] * A11[2][5]) + (B11[6] * A11[2][6]) + + * (B11[7] * A11[2][7])) / A11[2][2] + */ + + // extract a77 + ymm0 = _mm256_broadcast_sd((d11_pack + 7)); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm0); + _mm_storel_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (6 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (5 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm16, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm16, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (7 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm16, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 6)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm16, ymm9); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + _mm_storel_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (5 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (4 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm15, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm15, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (6 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm15, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm15, ymm9); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + _mm_storel_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (4 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (3 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm14, ymm13); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm14, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm14, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (5 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm14, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 4)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm14, ymm9); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + _mm_storel_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (3 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (2 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm1, ymm13, ymm12); + ymm1 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm13, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (4 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm13, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm13, ymm9); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + _mm_storel_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (2 * rs_a))); + ymm1 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (1 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm0 = _mm256_broadcast_sd((a11 + (3 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + _mm_storel_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (1 * rs_a))); + ymm0 = _mm256_broadcast_sd((a11 + (2 * cs_a) + (0 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + _mm_storel_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + + // extract a00 + ymm1 = _mm256_broadcast_sd((a11 + (1 * cs_a) + (0 * rs_a))); + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storel_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + m_remainder -= 1; + } + } + } + + dim_t n_remainder = j + d_nr; + + /* + Reminder cases starts here: + a. Similar logic and code flow used in computing full block (8x8) + above holds for reminder cases too. + */ + + if (n_remainder >= 4) + { + a01 = L + (n_remainder - 4) * rs_a + n_remainder * cs_a; // pointer to block of A to be used in GEMM + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n - n_remainder); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + for (dim_t x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); + bli_dcopys(*(a01 + rs_a * 2), *(ptr_a10_dup + (p_lda * 2))); + bli_dcopys(*(a01 + rs_a * 3), *(ptr_a10_dup + (p_lda * 3))); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = (n - n_remainder) / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 2) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 3) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 3) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 2) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 3) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 3) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + // read diagonal from a11 if not unit diagonal + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2) + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3) + 3)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 2) + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 3) + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (i) + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm10, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm10, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm10, ymm4); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b + 4)), ymm6); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2) + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3)), ymm9); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3) + 4), ymm10); + } + + dim_t m_remainder = i + d_mr; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4) + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3))); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3)), ymm9); + + m_remainder -= 4; + } + + if (m_remainder) + { + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 3))); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 3) + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3, 1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5, 1)); + _mm_storel_pd((double *)(b11 + (cs_b * 2) + 2), _mm256_extractf128_pd(ymm7, 1)); + _mm_storel_pd((double *)(b11 + (cs_b * 3) + 2), _mm256_extractf128_pd(ymm9, 1)); + + m_remainder -= 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx2m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 3))); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + m_remainder -= 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + (n_remainder - 4) * cs_a + (n_remainder - 4) * rs_a; // pointer to block of A to be used for TRSM + b10 = B + (n_remainder)*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (n_remainder - 4) * cs_b; // pointer to block of B to be used for TRSM + + k_iter = (n - n_remainder); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx1m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal);// register to hold alpha + + ymm0 = _mm256_broadcast_sd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 3))); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm9, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm9, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm9, ymm3); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm_storel_pd((b11 + (cs_b * 0)), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + (cs_b * 1)), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + m_remainder -= 1; + } + } + n_remainder -= 4; + } + + if (n_remainder == 3) + { + a01 = L + 3*cs_a; // pointer to block of A to be used in GEMM + a11 = L; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n - 3); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + for (dim_t x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + p_lda * 1)); + bli_dcopys(*(a01 + rs_a * 2), *(ptr_a10_dup + p_lda * 2)); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = (n - 3) / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 2) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 2) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2) + 2)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 2) + 2)); + } + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + i + 3*cs_b; // pointer to block of B to be used in GEMM + b11 = B + i; // pointer to block of B to be used for TRSM + + k_iter = (n - 3); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); + // B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); + // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); + // B11[4-7][1] * alpha -= ymm3 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2) + 4)); + // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); + // B11[4-7][2] * alpha -= ymm5 + + /// implement TRSM/// + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm8, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm8, ymm4); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2) + 4), ymm8); + } + + dim_t m_remainder = i + d_mr; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + 3*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4); // pointer to block of B to be used for TRSM + + k_iter = (n - 3); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + /// implement TRSM/// + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + + m_remainder -= 4; + } + + if (m_remainder) + { + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 3*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 3); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx3m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_3M(b11, cs_b) + + m_remainder -= 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 3*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 3); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx2m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_2M(b11, cs_b) + + m_remainder -= 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 3*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 3); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx1m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a) + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm7, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * cs_a))); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm7, ymm3); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_1M(b11, cs_b) + + m_remainder -= 1; + } + } + n_remainder -= 3; + } +else if ( n_remainder == 2) + { + a01 = L + 2*cs_a; // pointer to block of A to be used in GEMM + a11 = L; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n - 2); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + for (dim_t x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = (n - 2) / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + } + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + i + 2*cs_b; // pointer to block of B to be used in GEMM + b11 = B + i; // pointer to block of B to be used for TRSM + + k_iter = (n - 2); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_N_REM + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); + // B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); + // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); + // B11[4-7][1] * alpha -= ymm3 + + /// implement TRSM/// + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + ymm4 = _mm256_fnmadd_pd(ymm1, ymm6, ymm4); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + } + + dim_t m_remainder = i + d_mr; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + 2*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4); // pointer to block of B to be used for TRSM + + k_iter = (n - 2); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + /// implement TRSM/// + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + + m_remainder -= 4; + } + + if (m_remainder) + { + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 2*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 2); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx3m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_3M(b11, cs_b) + + m_remainder -= 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + (2)*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 2); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx2m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal, b11, cs_b) + /// implement TRSM/// + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_2M(b11, cs_b) + + m_remainder -= 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 2*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 2); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx1m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal, b11, cs_b) + /// implement TRSM/// + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(Row 1): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + cs_a)); + ymm3 = _mm256_fnmadd_pd(ymm1, ymm5, ymm3); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_1M(b11, cs_b) + + m_remainder -= 1; + } + } + n_remainder -= 2; + } + else if ( n_remainder == 1) + { + a01 = L + 1 * cs_a; // pointer to block of A to be used in GEMM + a11 = L; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = (n - 1); // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + for (dim_t x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01), *(ptr_a10_dup)); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = (n - 1) / 4; + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = (m - d_mr); (i + 1) > 0; i -= d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + i + 1*cs_b; // pointer to block of B to be used in GEMM + b11 = B + i; // pointer to block of B to be used for TRSM + + k_iter = (n - 1); // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx8m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); + // B11[4-7][0] * alpha-= ymm1 + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + } + + dim_t m_remainder = i + d_mr; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + (m_remainder - 4) + 1*cs_b; // pointer to block of B to be used in GEMM + b11 = B + (m_remainder - 4); // pointer to block of B to be used for TRSM + + k_iter = (n - 1); // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + + m_remainder -= 4; + } + + if (m_remainder) + { + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 1*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 1); // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx3m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + ymm0 = _mm256_loadu_pd((double const *)b11); + ymm3 = _mm256_blend_pd(ymm6, ymm3, 0x07); + + BLIS_POST_DTRSM_SMALL_1N_3M(b11, cs_b) + + m_remainder -= 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 1*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 1); // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx2m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_2M(b11, cs_b) + + m_remainder -= 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b10 = B + 1*cs_b; // pointer to block of B to be used in GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (n - 1); // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx1m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_1M(b11, cs_b) + + } + } + } + + if ((required_packing_A) && bli_mem_is_alloc(&local_mem_buf_A_s)) + { + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + return BLIS_SUCCESS; } // LLNN - LUTN