From 6ad387c2aaaf9ce8b280bf7990dc6bca5c7ac647 Mon Sep 17 00:00:00 2001
From: Aayush Kumar
Date: Wed, 12 Apr 2023 04:38:15 +0000
Subject: [PATCH] Added DTRSM Small Path AVX512 based LUNN/LLTN Variant Kernels

- 8x8 kernels are used for DTRSM SMALL.
- Matrix A (a10) is packed for GEMM operations.
- Packed matrix A is reused across all column blocks along the N-dimension.
- Diagonal elements of matrix A are packed (a11) for TRSM operations.
- Implemented fringe cases with the following block sizes:
      8x8, 8x4, 8x3, 8x2, 8x1
      4x8, 4x4, 4x3, 4x2, 4x1
      3x8, 3x4, 3x3, 3x2, 3x1
      2x8, 2x4, 2x3, 2x2, 2x1
      1x8, 1x4, 1x3, 1x2, 1x1

AMD-Internal: [CPUPL-2745]
Change-Id: I5bb57501f6d3783eb654e375d63901467dd14734
---
 frame/compat/bla_trsm_amd.c            |   15 +-
 kernels/zen4/3/bli_trsm_small_AVX512.c | 1791 +++++++++++++++++++++++-
 2 files changed, 1794 insertions(+), 12 deletions(-)

diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c
index f58aa3710..98af8991f 100644
--- a/frame/compat/bla_trsm_amd.c
+++ b/frame/compat/bla_trsm_amd.c
@@ -1059,9 +1059,9 @@ void dtrsm_blis_impl
     {
       case BLIS_ARCH_ZEN4:
 #if defined(BLIS_KERNELS_ZEN4)
-        // this is a temporary fix, will be removed when all variants are added
-        if( ((blis_side == BLIS_RIGHT) && ((n0 > 300) && (m0 > 50))) ||
-            ((blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_NO_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_TRANSPOSE) ) ) && ((n0 != 30 && n0 !=60 ) && (m0 > 50))) )
+        /* For sizes where both m and n are below 50, the AVX2 kernels
+           perform better, except when n is a multiple of 8. */
+        if (((n0 % 8 == 0) && (n0 < 50)) || ((m0 > 50) && (n0 > 50)))
         {
           ker_ft = bli_trsm_small_AVX512;
         }
@@ -1088,14 +1088,7 @@ void dtrsm_blis_impl
     {
       case BLIS_ARCH_ZEN4:
 #if defined(BLIS_KERNELS_ZEN4)
-        if ( (blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_NO_TRANSPOSE) ) ))
-        {
-          ker_ft = bli_trsm_small_mt;
-        }
-        else
-        {
-          ker_ft = bli_trsm_small_mt_AVX512;
-        }
+        ker_ft = bli_trsm_small_mt_AVX512;
         break;
 #endif// BLIS_KERNELS_ZEN4
       case BLIS_ARCH_ZEN:
diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c
index 639ada81e..ba5897b4e 100644
--- a/kernels/zen4/3/bli_trsm_small_AVX512.c
+++ b/kernels/zen4/3/bli_trsm_small_AVX512.c
@@ -9187,7 +9187,1796 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512
   cntl_t* cntl
 )
 {
-  return BLIS_NOT_YET_IMPLEMENTED;
+  dim_t m = bli_obj_length(b); // number of rows of B
+  dim_t n = bli_obj_width(b);  // number of columns of B
+  bool transa = bli_obj_has_trans(a);
+  dim_t cs_a, rs_a;
+  dim_t d_mr = 8, d_nr = 8;
+
+  // Swap rs_a & cs_a in case of non-transpose.
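+  /*
+    Note (illustrative): exchanging the strides lets one code path below
+    serve both variants named in this patch. With a column-major 8x8 A
+    (leading dimension 8), the transpose case (LLTN) reads A(i,j) at
+    *(L + j*8 + i) with rs_a = 1, cs_a = 8; swapping the strides in the
+    non-transpose case (LUNN) makes the same indexing address A as if it
+    were transposed, so upper-triangular A is consumed by the
+    lower-transpose code unchanged.
+  */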
+  if (transa)
+  {
+    cs_a = bli_obj_col_stride(a); // column stride of A
+    rs_a = bli_obj_row_stride(a); // row stride of A
+  }
+  else
+  {
+    cs_a = bli_obj_row_stride(a); // row stride of A
+    rs_a = bli_obj_col_stride(a); // column stride of A
+  }
+  dim_t cs_b = bli_obj_col_stride(b); // column stride of B
+  dim_t i, j, k;
+  dim_t k_iter;
+  double AlphaVal = *(double *)AlphaObj->buffer;
+  double *L = bli_obj_buffer_at_off(a); // pointer to matrix A
+  double *B = bli_obj_buffer_at_off(b); // pointer to matrix B
+
+  double *a10, *a11, *b01, *b11; // pointers for GEMM and TRSM blocks
+
+  double ones = 1.0;
+
+  gint_t required_packing_A = 1;
+  mem_t local_mem_buf_A_s = {0};
+  double *D_A_pack = NULL; // pointer to a10 pack buffer
+  double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack
+  rntm_t rntm;
+
+  bli_rntm_init_from_global(&rntm);
+  bli_rntm_set_num_threads_only(1, &rntm);
+  bli_membrk_rntm_set_membrk(&rntm);
+
+  siz_t buffer_size = bli_pool_block_size(
+      bli_membrk_pool(
+          bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+          bli_rntm_membrk(&rntm)));
+
+  if ((d_mr * m * sizeof(double)) > buffer_size)
+    return BLIS_NOT_YET_IMPLEMENTED;
+
+  if (required_packing_A == 1)
+  {
+    // Get the buffer from the pool.
+    bli_membrk_acquire_m(&rntm,
+                         buffer_size,
+                         BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                         &local_mem_buf_A_s);
+    if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s))
+      return BLIS_NULL_POINTER;
+    D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+    if (NULL == D_A_pack)
+      return BLIS_NULL_POINTER;
+  }
+  bool is_unitdiag = bli_obj_has_unit_diag(a);
+
+  __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11;
+  __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21;
+  __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
+  __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11;
+  __m256d ymm12, ymm13, ymm14, ymm15, ymm16;
+  __m128d xmm5;
+
+  /*
+    Performs TRSM on 8 rows of B at a time, iterating i from (m - d_mr)
+    down to 0 in steps of d_mr:
+    a. Load, transpose and pack A (a10 block); the packed size grows from
+       8x8 up to 8x(m - d_mr). In the first iteration there is no GEMM and
+       no packing of a10, because only TRSM is performed.
+    b. Perform the GEMM operation using the packed a10 block and the b01 block.
+    c. Using the GEMM output, perform the TRSM operation with a11 and b11,
+       and update B.
+    d. Repeat b,c for the n columns of B in steps of d_nr.
+  */
+
+  for (i = (m - d_mr); (i + 1) > 0; i -= d_mr)
+  {
+    a10 = L + (i * cs_a) + (i + d_mr) * rs_a;
+    a11 = L + (i * cs_a) + (i * rs_a);
+
+    dim_t p_lda = d_mr;
+    /*
+      Load, transpose and pack the current A block (a10) into the packed
+      buffer D_A_pack.
+      a. This a10 block is used only in the GEMM portion; its size grows by
+         d_mr with every iteration until it reaches 8x(m-8), the maximum
+         GEMM-only block size in A.
+      b. This packed buffer is reused to calculate all n columns of the B matrix.
+    */
+    bli_dtrsm_small_pack_avx512('L', (m - i - d_mr), transa, a10, bli_obj_col_stride(a), D_A_pack, p_lda, d_mr);
+
+    /*
+      Pack the 8 diagonal elements of the A block into an array.
+      a. This helps utilize cache lines efficiently in the TRSM operation.
+      b. Store ones when the input is unit diagonal.
+    */
+    dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, bli_obj_col_stride(a), d11_pack, d_mr);
+
+    for (j = (n - d_nr); (j + 1) > 0; j -= d_nr)
+    {
+      a10 = D_A_pack;
+      b01 = B + (j * cs_b) + i + d_mr; // pointer to block of B to be used for GEMM
+      b11 = B + (j * cs_b) + i;        // pointer to block of B to be used for TRSM
+
+      k_iter = (m - i - d_mr);
+
+      /* Fill zeros into zmm registers used in GEMM accumulation */
+      BLIS_SET_ZMM_REG_ZEROS
+
+      /*
+        Perform GEMM between the a10 and b01 blocks.
+        For the first iteration there is no GEMM operation,
+        since k_iter is zero.
+      */
+      BLIS_DTRSM_SMALL_GEMM_8mx8n_AVX512(a10, b01, cs_b, p_lda, k_iter, b11)
+
+      /*
+        Load b11 of size 8x8 and multiply it with alpha.
+        Add the GEMM output, and perform an in-register transpose of b11
+        for the TRSM operation.
+      */
+      BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8(b11, cs_b, AlphaVal)
+
+      // extract a77
+      zmm0 = _mm512_set1_pd(*(d11_pack + 7));
+      zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm0);
+
+      // extract a66
+      zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (6 * cs_a)));
+      zmm1 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (5 * cs_a)));
+      zmm15 = _mm512_fnmadd_pd(zmm0, zmm16, zmm15);
+      zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (4 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm1, zmm16, zmm14);
+      zmm1 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (3 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm0, zmm16, zmm13);
+      zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (2 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm1, zmm16, zmm12);
+      zmm1 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (1 * cs_a)));
+      zmm11 = _mm512_fnmadd_pd(zmm0, zmm16, zmm11);
+      zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (0 * cs_a)));
+      zmm10 = _mm512_fnmadd_pd(zmm1, zmm16, zmm10);
+      zmm1 = _mm512_set1_pd(*(d11_pack + 6));
+      zmm9 = _mm512_fnmadd_pd(zmm0, zmm16, zmm9);
+      zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm1);
+
+      // extract a55
+      zmm1 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (5 * cs_a)));
+      zmm0 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (4 * cs_a)));
+      zmm14 = _mm512_fnmadd_pd(zmm1, zmm15, zmm14);
+      zmm1 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (3 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm0, zmm15, zmm13);
+      zmm0 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (2 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm1, zmm15, zmm12);
+      zmm1 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (1 * cs_a)));
+      zmm11 = _mm512_fnmadd_pd(zmm0, zmm15, zmm11);
+      zmm0 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (0 * cs_a)));
+      zmm10 = _mm512_fnmadd_pd(zmm1, zmm15, zmm10);
+      zmm1 = _mm512_set1_pd(*(d11_pack + 5));
+      zmm9 = _mm512_fnmadd_pd(zmm0, zmm15, zmm9);
+      zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1);
+
+      // extract a44
+      zmm0 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (4 * cs_a)));
+      zmm1 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (3 * cs_a)));
+      zmm13 = _mm512_fnmadd_pd(zmm0, zmm14, zmm13);
+      zmm0 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (2 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm1, zmm14, zmm12);
+      zmm1 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (1 * cs_a)));
+      zmm11 = _mm512_fnmadd_pd(zmm0, zmm14, zmm11);
+      zmm0 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (0 * cs_a)));
+      zmm10 = _mm512_fnmadd_pd(zmm1, zmm14, zmm10);
+      zmm1 = _mm512_set1_pd(*(d11_pack + 4));
+      zmm9 = _mm512_fnmadd_pd(zmm0, zmm14, zmm9);
+      zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm1);
+
+      // extract a33
+      zmm1 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (3 * cs_a)));
+      zmm0 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (2 * cs_a)));
+      zmm12 = _mm512_fnmadd_pd(zmm1, zmm13, zmm12);
+      zmm1 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (1 * cs_a)));
+      zmm11 =
_mm512_fnmadd_pd(zmm0, zmm13, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm13, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 3)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm13, zmm9); + zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1); + + // extract a22 + zmm0 = _mm512_set1_pd(*(a11 + (3 * rs_a) + (2 * cs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (3 * rs_a) + (1 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm12, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (3 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm12, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 2)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm12, zmm9); + zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm1); + + // extract a11 + zmm1 = _mm512_set1_pd(*(a11 + (2 * rs_a) + (1 * cs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (2 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm11, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 1)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm11, zmm9); + zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1); + + // extract a00 + zmm1 = _mm512_set1_pd(*(a11 + (1 * rs_a) + (0 * cs_a))); + zmm0 = _mm512_set1_pd(*(d11_pack + 0)); + zmm9 = _mm512_fnmadd_pd(zmm1, zmm10, zmm9); + zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0); + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8_AND_STORE(b11, cs_b) + _mm512_storeu_pd((double *)(b11 + cs_b * 0), zmm0); + _mm512_storeu_pd((double *)(b11 + cs_b * 1), zmm1); + _mm512_storeu_pd((double *)(b11 + cs_b * 2), zmm2); + _mm512_storeu_pd((double *)(b11 + cs_b * 3), zmm3); + _mm512_storeu_pd((double *)(b11 + cs_b * 4), zmm4); + _mm512_storeu_pd((double *)(b11 + cs_b * 5), zmm5); + _mm512_storeu_pd((double *)(b11 + cs_b * 6), zmm6); + _mm512_storeu_pd((double *)(b11 + cs_b * 7), zmm7); + } + dim_t n_remainder = j + d_nr; + if (n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); + b01 = B + ((n_remainder - 4) * cs_b) + i + d_mr; + b11 = B + ((n_remainder - 4) * cs_b) + i; + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); + // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4)); + // B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3 + 4)); + // B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); // B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= 
B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); // B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); // B11[0-3][7] * alpha -= B01[0-3][7] + + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + // perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 * cs_a + 7 * rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 7 * rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 7 * rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 7 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 7 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7 * rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + // perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 6 * rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 6 * rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 6 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 6 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6 * rs_a)); + ymm16 = _mm256_broadcast_sd((double 
const *)(a11 + 6 * rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + // perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 5 * rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 5 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 5 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5 * rs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + // perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 4 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 4 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4 * rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = 
_mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); // B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); // store B11[7][0-3] + n_remainder -= 4; + } + + if (n_remainder) // implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); + b01 = B + i + d_mr; + b11 = B + i; + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx3n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); + // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4)); + // B11[0][6] B11[1][6] B11[2][6] B11[3][6] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); // B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = 
_mm256_broadcast_sd((double const *)(&ones)); + } + else if (2 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx2n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (1 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx1n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_broadcast_sd((double const *)(&ones)); + ymm6 = _mm256_broadcast_sd((double const *)(&ones)); + ymm7 = _mm256_broadcast_sd((double const *)(&ones)); + } + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3] + ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[1][4] B11[1][5] B11[3][4] B11[3][5] + ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[1][6] B11[1][7] B11[3][6] B11[3][7] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] 
B11[3][1] B11[3][2] B11[3][3] + + ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3] + ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3] + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7)); + + // perform mul operation + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6)); + + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 * cs_a + 7 * rs_a)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 7 * rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 7 * rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 7 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 7 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7 * rs_a)); + + //(ROw7): FMA operations + ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8); + + // perform mul operation + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5)); + + ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 6 * rs_a)); + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 6 * rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 6 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 6 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6 * rs_a)); + + //(ROw6): FMA operations + ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8); + + // perform mul operation + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4)); + + ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 5 * rs_a)); + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 5 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 5 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5 * rs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + // perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 4 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 4 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4 * rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, 
ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); // B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if (3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3] + } + else if (2 == 
n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + } + else if (1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + } + } + } + dim_t m_remainder = i + d_mr; + + if (m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i * cs_a) + (i + 4) * rs_a; // pointer to block of A to be used for GEMM + a11 = L + (i * cs_a) + (i * rs_a); // pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < m - i - 4; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < m - i - 4; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + bli_obj_col_stride(a) * 1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + bli_obj_col_stride(a) * 2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + bli_obj_col_stride(a) * 3 + 3)); + + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + // cols + for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b) + i + 4; // pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; // pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8(b11, cs_b, AlphaVal) + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const 
*)(d11_pack + 3)); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a + 2 * cs_a)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a + 1 * cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a + 0 * cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a + 1 * cs_a)); + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a + 0 * cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a + 0 * cs_a)); + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm14, ymm13); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8_AND_STORE(b11, cs_b) + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); + } + dim_t n_remainder = j + d_nr; + if ((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); // pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4) * cs_b) + i + 4; // pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4) * cs_b) + i; // pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // 
rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if (n_remainder) // implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if (3 == n_remainder) + { + 
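+            /* Note: in this 4x3 fringe the GEMM macro below accumulates
+               a10*b01 into ymm8-ymm10 (zeroed above); the fmsub that
+               follows forms alpha*B11 - accumulator, the right-hand side
+               required by the TRSM solve. The unused fourth column is
+               seeded with 1.0 ("ones") so the shared 4x4 transpose and
+               solve code can run unchanged. */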
BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (2 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (1 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, 
ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if (3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + } + else if (2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + } + else if (1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + } + } + m_remainder -= 4; + } + if (m_remainder) + { + + a10 = L + m_remainder * rs_a; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if (3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < m - m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < m - m_remainder; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0); + } + } + + // cols + for (j = 
(n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b) + m_remainder; // pointer to block of B to be used for GEMM + b11 = B + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + ymm0 =_mm256_broadcast_sd((double const *)(&AlphaVal)); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm1 = _mm256_insertf64x2(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm2 = _mm256_insertf64x2(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm3 = _mm256_insertf64x2(ymm3, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm4 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm4 = _mm256_insertf64x2(ymm4, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4)); + ymm5 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 4 + 2)); + ymm5 = _mm256_insertf64x2(ymm5, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5)); + ymm6 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 5 + 2)); + ymm6 = _mm256_insertf64x2(ymm6, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6)); + ymm7 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 6 + 2)); + ymm7 = _mm256_insertf64x2(ymm7, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7)); + ymm8 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 7 + 2)); + ymm8 = _mm256_insertf64x2(ymm8, xmm5, 0); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf64x2_pd(ymm9, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf64x2_pd(ymm10, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf64x2_pd(ymm11, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf64x2_pd(ymm12, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf64x2_pd(ymm13, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf64x2_pd(ymm14, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf64x2_pd(ymm15, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf64x2_pd(ymm16, 0)); + + _mm_storel_pd((double *)(b11 + cs_b * 0 + 2), _mm256_extractf64x2_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf64x2_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf64x2_pd(ymm11, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf64x2_pd(ymm12, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf64x2_pd(ymm13, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf64x2_pd(ymm14, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 6 + 2), 
_mm256_extractf64x2_pd(ymm15, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 7 + 2), _mm256_extractf64x2_pd(ymm16, 1)); + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 8, rs_a, cs_b, is_unitdiag); + } + + dim_t n_remainder = j + d_nr; + if ((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4) * cs_b) + m_remainder; // pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4) * cs_b); // pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11, 1)); + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + + if (n_remainder) + { + a10 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b01 = B + m_remainder; // pointer to block of B to be used for GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + 
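+          /* Note: for these 3-row fringes the BLIS_PRE_DTRSM_SMALL_3M_xN
+             macros apply alpha and subtract the GEMM accumulation from the
+             3xN B11 tile; the triangular solve itself is then delegated to
+             the reference kernels dtrsm_AltXB_ref / dtrsm_AuXB_ref, since a
+             3-row tile does not fill a vector register. */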
+        if (n_remainder)
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B; // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            if (3 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag);
+            }
+            else if (2 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag);
+            }
+            else if (1 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag);
+            }
+        }
+    }
+    else if (2 == m_remainder) // Repetitive A blocks will be 2*2
+    {
+        dim_t p_lda = 4; // packed leading dimension
+        if (transa)
+        {
+            for (dim_t x = 0; x < m - m_remainder; x += p_lda)
+            {
+                ymm0 = _mm256_loadu_pd((double const *)(a10));
+                ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a));
+                ymm2 = _mm256_broadcast_sd((double const *)&ones);
+                ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+                ymm4 = _mm256_unpacklo_pd(ymm0, ymm1);
+                ymm5 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+                ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+                ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+
+                ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);
+                ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);
+
+                ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+                ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+                _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+                _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7);
+                _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8);
+                _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9);
+
+                a10 += p_lda;
+                ptr_a10_dup += p_lda * p_lda;
+            }
+        }
+        else
+        {
+            for (dim_t x = 0; x < m - m_remainder; x++)
+            {
+                ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0);
+            }
+        }
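+
+        // a10 is now packed: when A needs a transpose, each 4x4 tile is transposed
+        // in registers (unpacklo/unpackhi + permute2f128) with unused lanes padded
+        // with ones. The N-dimension sweep below GEMM-updates each panel of B and
+        // then solves the two remaining rows with the reference TRSM kernel.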
+        // cols
+        for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + (j * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B + (j * cs_b); // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            /// GEMM code begins///
+            BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter)
+
+            ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal));
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));
+            ymm1 = _mm256_insertf64x2(ymm1, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));
+            ymm2 = _mm256_insertf64x2(ymm2, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2));
+            ymm3 = _mm256_insertf64x2(ymm3, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3));
+            ymm4 = _mm256_insertf64x2(ymm4, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4));
+            ymm5 = _mm256_insertf64x2(ymm5, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5));
+            ymm6 = _mm256_insertf64x2(ymm6, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6));
+            ymm7 = _mm256_insertf64x2(ymm7, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7));
+            ymm8 = _mm256_insertf64x2(ymm8, xmm5, 0);
+
+            ymm9  = _mm256_fmsub_pd(ymm1, ymm0, ymm9);
+            ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10);
+            ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11);
+            ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12);
+            ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13);
+            ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14);
+            ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15);
+            ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16);
+
+            _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf64x2_pd(ymm9, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf64x2_pd(ymm10, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf64x2_pd(ymm11, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf64x2_pd(ymm12, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf64x2_pd(ymm13, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf64x2_pd(ymm14, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf64x2_pd(ymm15, 0));
+            _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf64x2_pd(ymm16, 0));
+
+            if (transa)
+                dtrsm_AltXB_ref(a11, b11, m_remainder, 8, cs_a, cs_b, is_unitdiag);
+            else
+                dtrsm_AuXB_ref(a11, b11, m_remainder, 8, rs_a, cs_b, is_unitdiag);
+        }
+        dim_t n_remainder = j + d_nr;
+        if (n_remainder >= 4)
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + ((n_remainder - 4) * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B + ((n_remainder - 4) * cs_b); // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            /// GEMM code begins///
+            BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter)
+
+            ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+            /// implement TRSM///
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));
+            ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));
+            ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2));
+            ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);
+
+            xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3));
+            ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0);
+
+            ymm8  = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+            ymm9  = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+            ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+            ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+            _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8));
+            _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9));
+            _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10));
+            _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11));
+
+            if (transa)
+                dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag);
+            else
+                dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag);
+            n_remainder -= 4;
+        }
+        if (n_remainder)
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B; // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            if (3 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag);
+            }
+            else if (2 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag);
+            }
+            else if (1 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag);
+            }
+        }
+    }
+    else if (1 == m_remainder) // Repetitive A blocks will be 1*1
+    {
+        dim_t p_lda = 4; // packed leading dimension
+        if (transa)
+        {
+            for (dim_t x = 0; x < m - m_remainder; x += p_lda)
+            {
+                ymm0 = _mm256_loadu_pd((double const *)(a10));
+                ymm1 = _mm256_broadcast_sd((double const *)&ones);
+                ymm2 = _mm256_broadcast_sd((double const *)&ones);
+                ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+                ymm4 = _mm256_unpacklo_pd(ymm0, ymm1);
+                ymm5 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+                ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+                ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+
+                ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);
+                ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);
+
+                ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+                ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+                _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+                _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7);
+                _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8);
+                _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9);
+
+                a10 += p_lda;
+                ptr_a10_dup += p_lda * p_lda;
+            }
+        }
+        else
+        {
+            for (dim_t x = 0; x < m - m_remainder; x++)
+            {
+                ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a));
+                _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0);
+            }
+        }
+        // cols
+        for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + (j * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B + (j * cs_b); // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter)
+            /// GEMM code ends///
+
+            ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal));
+            ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0));
+            ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1));
+            ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2));
+            ymm4 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3));
+            ymm5 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 4));
+            ymm6 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 5));
+            ymm7 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 6));
+            ymm8 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 7));
+
+            ymm9  = _mm256_fmsub_pd(ymm1, ymm0, ymm9);
+            ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10);
+            ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11);
+            ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12);
+            ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13);
+            ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14);
+            ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15);
+            ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16);
+
+            _mm_storel_pd((double *)(b11 + cs_b * 0), _mm256_extractf64x2_pd(ymm9, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf64x2_pd(ymm10, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf64x2_pd(ymm11, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf64x2_pd(ymm12, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf64x2_pd(ymm13, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf64x2_pd(ymm14, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 6), _mm256_extractf64x2_pd(ymm15, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 7), _mm256_extractf64x2_pd(ymm16, 0));
+
+            if (transa)
+                dtrsm_AltXB_ref(a11, b11, m_remainder, 8, cs_a, cs_b, is_unitdiag);
+            else
+                dtrsm_AuXB_ref(a11, b11, m_remainder, 8, rs_a, cs_b, is_unitdiag);
+        }
+        dim_t n_remainder = j + d_nr;
+        if (n_remainder >= 4)
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + ((n_remainder - 4) * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B + ((n_remainder - 4) * cs_b); // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter)
+
+            ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+            /// implement TRSM///
+
+            ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0));
+            ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1));
+            ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2));
+            ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3));
+
+            ymm8  = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+            ymm9  = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+            ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+            ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+            _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10, 0));
+            _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11, 0));
+
+            if (transa)
+                dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag);
+            else
+                dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag);
+            n_remainder -= 4;
+        }
+        if (n_remainder)
+        {
+            a10 = D_A_pack;
+            a11 = L; // pointer to block of A to be used for TRSM
+            b01 = B + m_remainder; // pointer to block of B to be used for GEMM
+            b11 = B; // pointer to block of B to be used for TRSM
+
+            k_iter = (m - m_remainder); // number of times GEMM to be performed (in blocks of 4x4)
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            if (3 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag);
+            }
+            else if (2 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag);
+            }
+            else if (1 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal, b11, cs_b)
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag);
+            }
+        }
+    }
+    }
+
+    if ((required_packing_A == 1) &&
+        bli_mem_is_alloc(&local_mem_buf_A_s))
+    {
+        bli_membrk_release(&rntm, &local_mem_buf_A_s);
+    }
+    return BLIS_SUCCESS;
 }
 #endif
\ No newline at end of file
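
Note (illustrative, not part of the patch): the fringe paths above all lean on one partial-column idiom. Two doubles come in through a 128-bit load, the third through a broadcast, the halves are merged with _mm256_insertf64x2, combined with alpha, and written back with a 128-bit store plus _mm_storel_pd, so no access reaches past the valid rows of B. A minimal self-contained sketch of that idiom follows, with a plain multiply standing in for the kernel's fmsub against the GEMM accumulator; the helper name scale_3_rows is hypothetical, not a BLIS symbol. Build with -mavx512dq -mavx512vl (or -march=znver4).

#include <immintrin.h>

/* Scale three consecutive doubles by alpha using the same partial
   load/store pattern as the m_remainder == 3 paths above. */
static void scale_3_rows(double *col, double alpha)
{
    __m128d lo2 = _mm_loadu_pd(col);                      /* rows 0..1          */
    __m256d v   = _mm256_broadcast_sd(col + 2);           /* row 2 in all lanes */
    v = _mm256_insertf64x2(v, lo2, 0);                    /* [r0 r1 r2 r2]      */
    v = _mm256_mul_pd(v, _mm256_set1_pd(alpha));          /* alpha * column     */
    _mm_storeu_pd(col, _mm256_extractf64x2_pd(v, 0));     /* store rows 0..1    */
    _mm_storel_pd(col + 2, _mm256_extractf64x2_pd(v, 1)); /* store row 2 only   */
}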