diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c
index f58aa3710..98af8991f 100644
--- a/frame/compat/bla_trsm_amd.c
+++ b/frame/compat/bla_trsm_amd.c
@@ -1059,9 +1059,9 @@ void dtrsm_blis_impl
 {
     case BLIS_ARCH_ZEN4:
 #if defined(BLIS_KERNELS_ZEN4)
-        // this is a temporary fix, will be removed when all variants are added
-        if( ((blis_side == BLIS_RIGHT) && ((n0 > 300) && (m0 > 50))) ||
-            ((blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_NO_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_TRANSPOSE) ) ) && ((n0 != 30 && n0 !=60 ) && (m0 > 50))) )
+        /* For sizes where m and n < 50, the AVX2 kernels perform better,
+           except for sizes where n is a multiple of 8. */
+        if (((n0 % 8 == 0) && (n0 < 50)) || ((m0 > 50) && (n0 > 50)))
         {
             ker_ft = bli_trsm_small_AVX512;
         }
@@ -1088,14 +1088,7 @@ void dtrsm_blis_impl
 {
     case BLIS_ARCH_ZEN4:
 #if defined(BLIS_KERNELS_ZEN4)
-        if ( (blis_side == BLIS_LEFT && ( (blis_uploa == BLIS_LOWER && blis_transa == BLIS_TRANSPOSE) || (blis_uploa == BLIS_UPPER && blis_transa == BLIS_NO_TRANSPOSE) ) ))
-        {
-            ker_ft = bli_trsm_small_mt;
-        }
-        else
-        {
-            ker_ft = bli_trsm_small_mt_AVX512;
-        }
+        ker_ft = bli_trsm_small_mt_AVX512;
         break;
 #endif// BLIS_KERNELS_ZEN4
     case BLIS_ARCH_ZEN:
diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c
index 639ada81e..ba5897b4e 100644
--- a/kernels/zen4/3/bli_trsm_small_AVX512.c
+++ b/kernels/zen4/3/bli_trsm_small_AVX512.c
@@ -9187,7 +9187,1796 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512
   cntl_t* cntl
 )
 {
-    return BLIS_NOT_YET_IMPLEMENTED;
+    dim_t m = bli_obj_length(b); // number of rows
+    dim_t n = bli_obj_width(b);  // number of columns
+    bool transa = bli_obj_has_trans(a);
+    dim_t cs_a, rs_a;
+    dim_t d_mr = 8, d_nr = 8;
+
+    // Swap rs_a & cs_a in case of non-transpose.
+    if (transa)
+    {
+        cs_a = bli_obj_col_stride(a); // column stride of A
+        rs_a = bli_obj_row_stride(a); // row stride of A
+    }
+    else
+    {
+        cs_a = bli_obj_row_stride(a); // row stride of A
+        rs_a = bli_obj_col_stride(a); // column stride of A
+    }
+    dim_t cs_b = bli_obj_col_stride(b); // column stride of B
+    dim_t i, j, k;
+    dim_t k_iter;
+    double AlphaVal = *(double *)AlphaObj->buffer;
+    double *L = bli_obj_buffer_at_off(a); // pointer to matrix A
+    double *B = bli_obj_buffer_at_off(b); // pointer to matrix B
+
+    double *a10, *a11, *b01, *b11; // pointers for GEMM and TRSM blocks
+
+    double ones = 1.0;
+
+    gint_t required_packing_A = 1;
+    mem_t local_mem_buf_A_s = {0};
+    double *D_A_pack = NULL; // pointer to a10 pack buffer
+    double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack
+    rntm_t rntm;
+
+    bli_rntm_init_from_global(&rntm);
+    bli_rntm_set_num_threads_only(1, &rntm);
+    bli_membrk_rntm_set_membrk(&rntm);
+
+    siz_t buffer_size = bli_pool_block_size(
+        bli_membrk_pool(
+            bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK),
+            bli_rntm_membrk(&rntm)));
+
+    if ((d_mr * m * sizeof(double)) > buffer_size)
+        return BLIS_NOT_YET_IMPLEMENTED;
+
+    if (required_packing_A == 1)
+    {
+        // Get the buffer from the pool.
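+        // buffer_size was validated above, so a d_mr x m pack of A is
+        // guaranteed to fit in a single block from the memory broker.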
+        bli_membrk_acquire_m(&rntm,
+                             buffer_size,
+                             BLIS_BITVAL_BUFFER_FOR_A_BLOCK,
+                             &local_mem_buf_A_s);
+        if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s))
+            return BLIS_NULL_POINTER;
+        D_A_pack = bli_mem_buffer(&local_mem_buf_A_s);
+        if (NULL == D_A_pack)
+            return BLIS_NULL_POINTER;
+    }
+    bool is_unitdiag = bli_obj_has_unit_diag(a);
+
+    __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11;
+    __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21;
+    __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31;
+    __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11;
+    __m256d ymm12, ymm13, ymm14, ymm15, ymm16;
+    __m128d xmm5;
+
+    /*
+        Performs TRSM for 8 rows at a time from 0 to m/d_mr in steps of d_mr
+        a. Load, transpose, pack A (a10 block); the packing grows from 8x8 to 8x(m-d_mr).
+           For the first iteration there is no GEMM and no packing of a10 because it is only TRSM.
+        b. Using the packed a10 block and the b01 block, perform the GEMM operation.
+        c. Use the GEMM outputs, perform the TRSM operation using a11, b11 and update B.
+        d. Repeat b, c for n columns of B in steps of d_nr.
+    */
+
+    for (i = (m - d_mr); (i + 1) > 0; i -= d_mr)
+    {
+        a10 = L + (i * cs_a) + (i + d_mr) * rs_a;
+        a11 = L + (i * cs_a) + (i * rs_a);
+
+        dim_t p_lda = d_mr;
+        /*
+            Load, transpose and pack the current A block (a10) into packed buffer memory D_A_pack
+            a. This a10 block is used in the GEMM portion only, and its size increases by d_mr
+               on every iteration until it reaches 8x(m-8), the maximum GEMM-only block size in A.
+            b. This packed buffer is reused to calculate all n columns of the B matrix.
+        */
+        bli_dtrsm_small_pack_avx512('L', (m - i - d_mr), transa, a10, bli_obj_col_stride(a), D_A_pack, p_lda, d_mr);
+
+        /*
+            Pack the 8 diagonal elements of the A block into an array
+            a. This helps utilize the cache line efficiently in the TRSM operation.
+            b. Store ones when the input is unit-diagonal.
+        */
+        dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, bli_obj_col_stride(a), d11_pack, d_mr);
+
+        for (j = (n - d_nr); (j + 1) > 0; j -= d_nr)
+        {
+            a10 = D_A_pack;
+            b01 = B + (j * cs_b) + i + d_mr; // pointer to block of B to be used for GEMM
+            b11 = B + (j * cs_b) + i;        // pointer to block of B to be used for TRSM
+
+            k_iter = (m - i - d_mr);
+
+            /* Fill zeros into zmm registers used in gemm accumulations */
+            BLIS_SET_ZMM_REG_ZEROS
+
+            /*
+                Perform GEMM between the a10 and b01 blocks.
+                For the first iteration there is no GEMM operation,
+                since k_iter is zero.
+            */
+            BLIS_DTRSM_SMALL_GEMM_8mx8n_AVX512(a10, b01, cs_b, p_lda, k_iter, b11)
+
+            /*
+                Load b11 of size 8x8 and multiply with alpha
+                Add the GEMM output and perform in register transpose of b11
+                to perform TRSM operation.
+ */ + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8(b11, cs_b, AlphaVal) + + // extract a77 + zmm0 = _mm512_set1_pd(*(d11_pack + 7)); + zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm0); + + // extract a66 + zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (6 * cs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (5 * cs_a))); + zmm15 = _mm512_fnmadd_pd(zmm0, zmm16, zmm15); + zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (4 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm1, zmm16, zmm14); + zmm1 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (3 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm16, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (2 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm16, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (1 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm16, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (7 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm16, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 6)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm16, zmm9); + zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm1); + + // extract a55 + zmm1 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (5 * cs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (4 * cs_a))); + zmm14 = _mm512_fnmadd_pd(zmm1, zmm15, zmm14); + zmm1 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (3 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm15, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (2 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm15, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (1 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm15, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (6 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm15, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 5)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm15, zmm9); + zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1); + + // extract a44 + zmm0 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (4 * cs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (3 * cs_a))); + zmm13 = _mm512_fnmadd_pd(zmm0, zmm14, zmm13); + zmm0 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (2 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm14, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (1 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm14, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (5 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm14, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 4)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm14, zmm9); + zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm1); + + // extract a33 + zmm1 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (3 * cs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (2 * cs_a))); + zmm12 = _mm512_fnmadd_pd(zmm1, zmm13, zmm12); + zmm1 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (1 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm13, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (4 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm13, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 3)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm13, zmm9); + zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1); + + // extract a22 + zmm0 = _mm512_set1_pd(*(a11 + (3 * rs_a) + (2 * cs_a))); + zmm1 = _mm512_set1_pd(*(a11 + (3 * rs_a) + (1 * cs_a))); + zmm11 = _mm512_fnmadd_pd(zmm0, zmm12, zmm11); + zmm0 = _mm512_set1_pd(*(a11 + (3 * rs_a) + (0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm12, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 2)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm12, zmm9); + zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm1); + + // extract a11 + zmm1 = _mm512_set1_pd(*(a11 + (2 * rs_a) + (1 * cs_a))); + zmm0 = _mm512_set1_pd(*(a11 + (2 * rs_a) + 
(0 * cs_a))); + zmm10 = _mm512_fnmadd_pd(zmm1, zmm11, zmm10); + zmm1 = _mm512_set1_pd(*(d11_pack + 1)); + zmm9 = _mm512_fnmadd_pd(zmm0, zmm11, zmm9); + zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1); + + // extract a00 + zmm1 = _mm512_set1_pd(*(a11 + (1 * rs_a) + (0 * cs_a))); + zmm0 = _mm512_set1_pd(*(d11_pack + 0)); + zmm9 = _mm512_fnmadd_pd(zmm1, zmm10, zmm9); + zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0); + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_8x8_AND_STORE(b11, cs_b) + _mm512_storeu_pd((double *)(b11 + cs_b * 0), zmm0); + _mm512_storeu_pd((double *)(b11 + cs_b * 1), zmm1); + _mm512_storeu_pd((double *)(b11 + cs_b * 2), zmm2); + _mm512_storeu_pd((double *)(b11 + cs_b * 3), zmm3); + _mm512_storeu_pd((double *)(b11 + cs_b * 4), zmm4); + _mm512_storeu_pd((double *)(b11 + cs_b * 5), zmm5); + _mm512_storeu_pd((double *)(b11 + cs_b * 6), zmm6); + _mm512_storeu_pd((double *)(b11 + cs_b * 7), zmm7); + } + dim_t n_remainder = j + d_nr; + if (n_remainder >= 4) + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); + b01 = B + ((n_remainder - 4) * cs_b) + i + d_mr; + b11 = B + ((n_remainder - 4) * cs_b) + i; + + k_iter = (m - i - d_mr); + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_8mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); + // B11[0][4] B11[1][4] B11[2][4] B11[3][4] + ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); + // B11[0][5] B11[1][5] B11[2][5] B11[3][5] + ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4)); + // B11[0][6] B11[1][6] B11[2][6] B11[3][6] + ymm7 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3 + 4)); + // B11[0][7] B11[1][7] B11[2][7] B11[3][7] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); // B11[0-3][3] * alpha -= B01[0-3][3] + ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4] + ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5] + ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); // B11[0-3][6] * alpha -= B01[0-3][6] + ymm7 = _mm256_fmsub_pd(ymm7, ymm16, ymm15); // B11[0-3][7] * alpha -= B01[0-3][7] + + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[0][4] B11[0][5] B11[2][4] B11[2][5] + ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[0][6] B11[0][7] B11[2][6] B11[2][7] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 
0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+
+            ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3]
+            ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3]
+
+            ////unpackhigh////
+            ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1]
+            ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3]
+
+            ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[1][4] B11[1][5] B11[3][4] B11[3][5]
+            ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[1][6] B11[1][7] B11[3][6] B11[3][7]
+
+            // rearrange high elements
+            ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3]
+            ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3]
+
+            ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3]
+            ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3]
+
+            // extract a77
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7));
+
+            // perform mul operation
+            ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1);
+
+            // extract a66
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6));
+
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 * cs_a + 7 * rs_a));
+            ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 7 * rs_a));
+            ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 7 * rs_a));
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 7 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 7 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7 * rs_a));
+
+            //(ROw7): FMA operations
+            ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14);
+            ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13);
+            ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12);
+            ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11);
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8);
+
+            // perform mul operation
+            ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1);
+
+            // extract a55
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5));
+
+            ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 6 * rs_a));
+            ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 6 * rs_a));
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 6 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 6 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6 * rs_a));
+
+            //(ROw6): FMA operations
+            ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13);
+            ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12);
+            ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11);
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8);
+
+            // perform mul operation
+            ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1);
+
+            // extract a44
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4));
+
+            ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 5 * rs_a));
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 5 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 5 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 +
5 * rs_a)); + + //(ROw5): FMA operations + ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12); + ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8); + + // perform mul operation + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 4 * rs_a)); + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 4 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4 * rs_a)); + + //(ROw4): FMA operations + ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11); + ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = 
_mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3]
+
+            ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); // B11[4][1] B11[5][1] B11[6][1] B11[7][1]
+            ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); // B11[4][3] B11[5][3] B11[6][3] B11[7][3]
+
+            _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3]
+            _mm256_storeu_pd((double *)(b11 + cs_b * 3 + 4), ymm7); // store B11[7][0-3]
+            n_remainder -= 4;
+        }
+
+        if (n_remainder) // implementation for remaining columns (when 'N' is not a multiple of d_nr); n_remainder = 1, 2 or 3
+        {
+            a10 = D_A_pack;
+            a11 = L + (i * cs_a) + (i * rs_a);
+            b01 = B + i + d_mr;
+            b11 = B + i;
+
+            k_iter = (m - i - d_mr);
+
+            /* Fill zeros into ymm registers used in gemm accumulations */
+            BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+            if (3 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_8mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+                ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0));
+                // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+                ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1));
+                // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+                ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2));
+                // B11[0][2] B11[1][2] B11[2][2] B11[3][2]
+
+                ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4));
+                // B11[0][4] B11[1][4] B11[2][4] B11[3][4]
+                ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4));
+                // B11[0][5] B11[1][5] B11[2][5] B11[3][5]
+                ymm6 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4));
+                // B11[0][6] B11[1][6] B11[2][6] B11[3][6]
+
+                ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);  // B11[0-3][0] * alpha -= B01[0-3][0]
+                ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);  // B11[0-3][1] * alpha -= B01[0-3][1]
+                ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2]
+                ymm3 = _mm256_broadcast_sd((double const *)(&ones));
+
+                ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4]
+                ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5]
+                ymm6 = _mm256_fmsub_pd(ymm6, ymm16, ymm14); // B11[0-3][6] * alpha -= B01[0-3][6]
+                ymm7 = _mm256_broadcast_sd((double const *)(&ones));
+            }
+            else if (2 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_8mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+                ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+                ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); // B11[0][1] B11[1][1] B11[2][1] B11[3][1]
+
+                ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); // B11[0][4] B11[1][4] B11[2][4] B11[3][4]
+                ymm5 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1 + 4)); // B11[0][5] B11[1][5] B11[2][5] B11[3][5]
+
+                ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0]
+                ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1]
+                ymm2 = _mm256_broadcast_sd((double const *)(&ones));
+                ymm3 = _mm256_broadcast_sd((double const *)(&ones));
+
+                ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4]
+                ymm5 = _mm256_fmsub_pd(ymm5, ymm16, ymm13); // B11[0-3][5] * alpha -= B01[0-3][5]
+                ymm6 = _mm256_broadcast_sd((double const *)(&ones));
+                ymm7 = _mm256_broadcast_sd((double const *)(&ones));
+            }
+            else if (1 == n_remainder)
+            {
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_8mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+                ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0]
+
+                ymm4 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0 + 4)); // B11[0][4] B11[1][4] B11[2][4] B11[3][4]
+
+                ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0]
+                ymm1 = _mm256_broadcast_sd((double const *)(&ones));
+                ymm2 = _mm256_broadcast_sd((double const *)(&ones));
+                ymm3 = _mm256_broadcast_sd((double const *)(&ones));
+
+                ymm4 = _mm256_fmsub_pd(ymm4, ymm16, ymm12); // B11[0-3][4] * alpha -= B01[0-3][4]
+                ymm5 = _mm256_broadcast_sd((double const *)(&ones));
+                ymm6 = _mm256_broadcast_sd((double const *)(&ones));
+                ymm7 = _mm256_broadcast_sd((double const *)(&ones));
+            }
+            /// implement TRSM///
+
+            /// transpose of B11//
+            /// unpacklow///
+            ymm9 = _mm256_unpacklo_pd(ymm0, ymm1);  // B11[0][0] B11[0][1] B11[2][0] B11[2][1]
+            ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3]
+
+            ymm13 = _mm256_unpacklo_pd(ymm4, ymm5); // B11[0][4] B11[0][5] B11[2][4] B11[2][5]
+            ymm15 = _mm256_unpacklo_pd(ymm6, ymm7); // B11[0][6] B11[0][7] B11[2][6] B11[2][7]
+
+            // rearrange low elements
+            ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20);  // B11[0][0] B11[0][1] B11[0][2] B11[0][3]
+            ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3]
+
+            ymm12 = _mm256_permute2f128_pd(ymm13, ymm15, 0x20); // B11[4][0] B11[4][1] B11[4][2] B11[4][3]
+            ymm14 = _mm256_permute2f128_pd(ymm13, ymm15, 0x31); // B11[6][0] B11[6][1] B11[6][2] B11[6][3]
+
+            ////unpackhigh////
+            ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1]
+            ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3]
+
+            ymm4 = _mm256_unpackhi_pd(ymm4, ymm5); // B11[1][4] B11[1][5] B11[3][4] B11[3][5]
+            ymm5 = _mm256_unpackhi_pd(ymm6, ymm7); // B11[1][6] B11[1][7] B11[3][6] B11[3][7]
+
+            // rearrange high elements
+            ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);  // B11[1][0] B11[1][1] B11[1][2] B11[1][3]
+            ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3]
+
+            ymm13 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); // B11[5][0] B11[5][1] B11[5][2] B11[5][3]
+            ymm15 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); // B11[7][0] B11[7][1] B11[7][2] B11[7][3]
+
+            // extract a77
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 7));
+
+            // perform mul operation
+            ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1);
+
+            // extract a66
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 6));
+
+            ymm2 = _mm256_broadcast_sd((double const *)(a11 + 6 * cs_a + 7 * rs_a));
+            ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 7 * rs_a));
+            ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 7 * rs_a));
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 7 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 7 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 7 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 7 * rs_a));
+
+            //(ROw7): FMA operations
+            ymm14 = _mm256_fnmadd_pd(ymm2, ymm15, ymm14);
+            ymm13 = _mm256_fnmadd_pd(ymm3, ymm15, ymm13);
+            ymm12 = _mm256_fnmadd_pd(ymm4, ymm15, ymm12);
+            ymm11 = _mm256_fnmadd_pd(ymm5, ymm15, ymm11);
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm15, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm15, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm15, ymm8);
+
+            // perform mul operation
+            ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1);
+
+            // extract a55
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 5));
+
+            ymm3 = _mm256_broadcast_sd((double const *)(a11 + 5 * cs_a + 6 * rs_a));
+            ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 6 * rs_a));
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 6 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 6 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 6 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 6 * rs_a));
+
+            //(ROw6): FMA operations
+            ymm13 = _mm256_fnmadd_pd(ymm3, ymm14, ymm13);
+            ymm12 = _mm256_fnmadd_pd(ymm4, ymm14, ymm12);
+            ymm11 = _mm256_fnmadd_pd(ymm5, ymm14, ymm11);
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm14, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm14, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm14, ymm8);
+
+            // perform mul operation
+            ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm1);
+
+            // extract a44
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 4));
+
+            ymm4 = _mm256_broadcast_sd((double const *)(a11 + 4 * cs_a + 5 * rs_a));
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 5 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 5 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 5 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 5 * rs_a));
+
+            //(ROw5): FMA operations
+            ymm12 = _mm256_fnmadd_pd(ymm4, ymm13, ymm12);
+            ymm11 = _mm256_fnmadd_pd(ymm5, ymm13, ymm11);
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm13, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm13, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm13, ymm8);
+
+            // perform mul operation
+            ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1);
+
+            // extract a33
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3));
+
+            ymm5 = _mm256_broadcast_sd((double const *)(a11 + 3 * cs_a + 4 * rs_a));
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 4 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 4 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 4 * rs_a));
+
+            //(ROw4): FMA operations
+            ymm11 = _mm256_fnmadd_pd(ymm5, ymm12, ymm11);
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm12, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm12, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm12, ymm8);
+
+            // perform mul operation
+            ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1);
+
+            // extract a22
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2));
+
+            ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a));
+            ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a));
+            ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a));
+
+            //(ROw3): FMA operations
+            ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10);
+            ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9);
+            ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8);
+
+            // perform mul operation
+            ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1);
+
+            // extract a11
+            ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1));
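+            // eliminate the row-2 contribution from rows 1 and 0 below, then
+            // scale row 1 by its packed diagonal entry via DTRSM_SMALL_DIV_OR_SCALE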
+ + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); // B11[4][0] B11[5][0] B11[4][2] B11[5][2] + ymm7 = _mm256_unpacklo_pd(ymm14, ymm15); // B11[6][0] B11[7][0] B11[6][2] B11[7][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm4 = _mm256_permute2f128_pd(ymm5, ymm7, 0x20); // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + ymm6 = _mm256_permute2f128_pd(ymm5, ymm7, 0x31); // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + ymm12 = _mm256_unpackhi_pd(ymm12, ymm13); // B11[4][1] B11[5][1] B11[4][3] B11[5][3] + ymm13 = _mm256_unpackhi_pd(ymm14, ymm15); // B11[6][1] B11[7][1] B11[6][3] B11[7][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + ymm5 = _mm256_permute2f128_pd(ymm12, ymm13, 0x20); // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + ymm7 = _mm256_permute2f128_pd(ymm12, ymm13, 0x31); // B11[4][3] B11[5][3] B11[6][3] B11[7][3] + + if (3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2 + 4), ymm6); // store B11[6][0-3] + } + else if (2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1 + 4), ymm5); // store B11[5][0-3] + } + else if (1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 0 + 4), ymm4); // store B11[4][0-3] + } + } + } + dim_t m_remainder = i + d_mr; + + if (m_remainder >= 4) + { + i = m_remainder - 4; + a10 = L + (i * cs_a) + (i + 4) * rs_a; // pointer to block of A to be used for GEMM + a11 = L + (i * cs_a) + (i * rs_a); // pointer to block of A to be used for TRSM + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = 
D_A_pack; + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < m - i - 4; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a10 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < m - i - 4; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + bli_obj_col_stride(a) * 1 + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + bli_obj_col_stride(a) * 2 + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + bli_obj_col_stride(a) * 3 + 3)); + + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + // cols + for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b) + i + 4; // pointer to block of B to be used for GEMM + b11 = B + (j * cs_b) + i; // pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8(b11, cs_b, AlphaVal) + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a + 2 * cs_a)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a + 1 * cs_a)); + ymm11 = _mm256_fnmadd_pd(ymm0, ymm12, ymm11); + ymm15 = _mm256_fnmadd_pd(ymm0, ymm16, ymm15); + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a + 0 * cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm12, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm16, ymm14); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm12, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm16, ymm13); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm1); + + + // extract a11 + ymm1 = 
_mm256_broadcast_sd((double const *)(a11 + 2 * rs_a + 1 * cs_a)); + ymm0 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a + 0 * cs_a)); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm11, ymm10); + ymm14 = _mm256_fnmadd_pd(ymm1, ymm15, ymm14); + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + ymm9 = _mm256_fnmadd_pd(ymm0, ymm11, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm0, ymm15, ymm13); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a + 0 * cs_a)); + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 0)); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm10, ymm9); + ymm13 = _mm256_fnmadd_pd(ymm1, ymm14, ymm13); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + + + BLIS_DTRSM_SMALL_NREG_TRANSPOSE_4x8_AND_STORE(b11, cs_b) + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b * 6), ymm6); + _mm256_storeu_pd((double *)(b11 + cs_b * 7), ymm7); + } + dim_t n_remainder = j + d_nr; + if ((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); // pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4) * cs_b) + i + 4; // pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4) * cs_b) + i; // pointer to block of B to be used for TRSM + + k_iter = (m - i - 4); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm3 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = 
_mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(ROw3): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm11, ymm10); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm11, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm9 = _mm256_fnmadd_pd(ymm2, ymm10, ymm9); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + //(ROw2): FMA operations + ymm2 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + ymm8 = _mm256_fnmadd_pd(ymm2, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); // store B11[3][0-3] + n_remainder = n_remainder - 4; + } + + if (n_remainder) // implementation fo remaining columns(when 'N' is not a multiple of d_nr)() n = 3 + { + a10 = D_A_pack; + a11 = L + (i * cs_a) + (i * rs_a); + b01 = B + i + 4; + b11 = B + i; + + k_iter = (m - i - 4); + + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + + if (3 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm2 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); // B11[0-3][2] * alpha -= B01[0-3][2] + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (2 == n_remainder) + { + 
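+                // two columns remain: run the 4m x 2n GEMM, then pad ymm2/ymm3
+                // with ones so the 4x4 transpose below stays well-defined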
BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 1)); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); // B11[0-3][1] * alpha -= B01[0-3][1] + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + else if (1 == n_remainder) + { + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 0)); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + + ymm0 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); // B11[0-3][0] * alpha -= B01[0-3][0] + ymm1 = _mm256_broadcast_sd((double const *)(&ones)); + ymm2 = _mm256_broadcast_sd((double const *)(&ones)); + ymm3 = _mm256_broadcast_sd((double const *)(&ones)); + } + + /// implement TRSM/// + + /// transpose of B11// + /// unpacklow/// + ymm9 = _mm256_unpacklo_pd(ymm0, ymm1); // B11[0][0] B11[0][1] B11[2][0] B11[2][1] + ymm11 = _mm256_unpacklo_pd(ymm2, ymm3); // B11[0][2] B11[0][3] B11[2][2] B11[2][3] + + // rearrange low elements + ymm8 = _mm256_permute2f128_pd(ymm9, ymm11, 0x20); // B11[0][0] B11[0][1] B11[0][2] B11[0][3] + ymm10 = _mm256_permute2f128_pd(ymm9, ymm11, 0x31); // B11[2][0] B11[2][1] B11[2][2] B11[2][3] + + ////unpackhigh//// + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); // B11[1][0] B11[1][1] B11[3][0] B11[3][1] + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); // B11[1][2] B11[1][3] B11[3][2] B11[3][3] + + // rearrange high elements + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); // B11[1][0] B11[1][1] B11[1][2] B11[1][3] + ymm11 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); // B11[3][0] B11[3][1] B11[3][2] B11[3][3] + + // extract a33 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + // perform mul operation + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm1); + + // extract a22 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + ymm6 = _mm256_broadcast_sd((double const *)(a11 + 2 * cs_a + 3 * rs_a)); + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 3 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 3 * rs_a)); + + //(ROw3): FMA operations + ymm10 = _mm256_fnmadd_pd(ymm6, ymm11, ymm10); + ymm9 = _mm256_fnmadd_pd(ymm7, ymm11, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm11, ymm8); + + // perform mul operation + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + + // extract a11 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + ymm7 = _mm256_broadcast_sd((double const *)(a11 + cs_a + 2 * rs_a)); + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 2 * rs_a)); + + //(ROw2): FMA operations + ymm9 = _mm256_fnmadd_pd(ymm7, ymm10, ymm9); + ymm8 = _mm256_fnmadd_pd(ymm16, ymm10, ymm8); + + // perform mul operation + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm1); + + // extract a00 + ymm1 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm16 = _mm256_broadcast_sd((double const *)(a11 + 1 * rs_a)); + + //(ROw2): FMA operations + ymm8 = _mm256_fnmadd_pd(ymm16, ymm9, ymm8); + + // perform mul operation + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm1); + + // unpacklow// + ymm1 = _mm256_unpacklo_pd(ymm8, ymm9); // B11[0][0] B11[1][0] 
B11[0][2] B11[1][2] + ymm3 = _mm256_unpacklo_pd(ymm10, ymm11); // B11[2][0] B11[3][0] B11[2][2] B11[3][2] + + // rearrange low elements + ymm0 = _mm256_permute2f128_pd(ymm1, ymm3, 0x20); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm2 = _mm256_permute2f128_pd(ymm1, ymm3, 0x31); // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + + /// unpack high/// + ymm8 = _mm256_unpackhi_pd(ymm8, ymm9); // B11[0][1] B11[1][1] B11[0][3] B11[1][3] + ymm9 = _mm256_unpackhi_pd(ymm10, ymm11); // B11[2][1] B11[3][1] B11[2][3] B11[3][3] + + // rearrange high elements + ymm1 = _mm256_permute2f128_pd(ymm8, ymm9, 0x20); // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm3 = _mm256_permute2f128_pd(ymm8, ymm9, 0x31); // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + + if (3 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); // store B11[2][0-3] + } + else if (2 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); // store B11[1][0-3] + } + else if (1 == n_remainder) + { + _mm256_storeu_pd((double *)(b11 + cs_b * 0), ymm0); // store B11[0][0-3] + } + } + m_remainder -= 4; + } + if (m_remainder) + { + + a10 = L + m_remainder * rs_a; + + // Do transpose for a10 & store in D_A_pack + double *ptr_a10_dup = D_A_pack; + if (3 == m_remainder) // Repetative A blocks will be 3*3 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < m - m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a10 + cs_a * 2)); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9); + + a10 += p_lda; + ptr_a10_dup += p_lda * p_lda; + } + } + else + { + for (dim_t x = 0; x < m - m_remainder; x++) + { + ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a)); + _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0); + } + } + + // cols + for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension + { + a10 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b01 = B + (j * cs_b) + m_remainder; // pointer to block of B to be used for GEMM + b11 = B + (j * cs_b); // pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter) + ymm0 =_mm256_broadcast_sd((double const *)(&AlphaVal)); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm1 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm1 = _mm256_insertf64x2(ymm1, xmm5, 0); + 
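+                // only m_remainder (= 3) rows per column are valid: load rows 0-1
+                // with a 128-bit load, broadcast row 2 into the upper lanes, and
+                // insert the pair into lane 0; the same pattern repeats per column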
+ xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm2 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm2 = _mm256_insertf64x2(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm3 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm3 = _mm256_insertf64x2(ymm3, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm4 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm4 = _mm256_insertf64x2(ymm4, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4)); + ymm5 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 4 + 2)); + ymm5 = _mm256_insertf64x2(ymm5, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5)); + ymm6 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 5 + 2)); + ymm6 = _mm256_insertf64x2(ymm6, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6)); + ymm7 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 6 + 2)); + ymm7 = _mm256_insertf64x2(ymm7, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7)); + ymm8 =_mm256_broadcast_sd((double const *)(b11 + cs_b * 7 + 2)); + ymm8 = _mm256_insertf64x2(ymm8, xmm5, 0); + + ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11); + ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12); + ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13); + ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14); + ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15); + ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16); + + _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf64x2_pd(ymm9, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf64x2_pd(ymm10, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf64x2_pd(ymm11, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf64x2_pd(ymm12, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf64x2_pd(ymm13, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf64x2_pd(ymm14, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf64x2_pd(ymm15, 0)); + _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf64x2_pd(ymm16, 0)); + + _mm_storel_pd((double *)(b11 + cs_b * 0 + 2), _mm256_extractf64x2_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf64x2_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf64x2_pd(ymm11, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf64x2_pd(ymm12, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 4 + 2), _mm256_extractf64x2_pd(ymm13, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 5 + 2), _mm256_extractf64x2_pd(ymm14, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 6 + 2), _mm256_extractf64x2_pd(ymm15, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 7 + 2), _mm256_extractf64x2_pd(ymm16, 1)); + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 8, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 8, rs_a, cs_b, is_unitdiag); + } + + dim_t n_remainder = j + d_nr; + if ((n_remainder >= 4)) + { + a10 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b01 = B + ((n_remainder - 4) * cs_b) + m_remainder; // pointer to block of B to be used for GEMM + b11 = B + ((n_remainder - 4) * cs_b); // pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + /// GEMM code 
begins/// + BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter) + + ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + /// implement TRSM/// + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0)); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0 + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); + ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); + ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0); + + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3)); + ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3 + 2)); + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0); + + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8); + ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9); + ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10); + ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11); + + _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8)); + _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9)); + _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10)); + _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11)); + + _mm_storel_pd((double *)(b11 + 2), _mm256_extractf128_pd(ymm8, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm9, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm10, 1)); + _mm_storel_pd((double *)(b11 + cs_b * 3 + 2), _mm256_extractf128_pd(ymm11, 1)); + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag); + else + dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag); + n_remainder -= 4; + } + + if (n_remainder) + { + a10 = D_A_pack; + a11 = L; // pointer to block of A to be used for TRSM + b01 = B + m_remainder; // pointer to block of B to be used for GEMM + b11 = B; // pointer to block of B to be used for TRSM + + k_iter = (m - m_remainder); // number of times GEMM to be performed(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + BLIS_SET_YMM_REG_ZEROS_FOR_LEFT + + + if (3 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_3N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag); + else dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag); + } + else if (2 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_2N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag); + else dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag); + } + else if (1 == n_remainder) + { + /// GEMM code begins/// + BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3M_1N(AlphaVal, b11, cs_b) + + if (transa) + dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag); + else dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag); + } + } + } + else if (2 == m_remainder) // Repetative A blocks will be 2*2 + { + dim_t p_lda = 4; // packed leading dimension + if (transa) + { + for (dim_t x = 0; x < m - m_remainder; x += p_lda) + { + ymm0 = _mm256_loadu_pd((double const *)(a10)); + ymm1 = 
+        else if (2 == m_remainder) // remaining A blocks are 2x2
+        {
+            dim_t p_lda = 4; // packed leading dimension
+            if (transa)
+            {
+                for (dim_t x = 0; x < m - m_remainder; x += p_lda)
+                {
+                    ymm0 = _mm256_loadu_pd((double const *)(a10));
+                    ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a));
+                    ymm2 = _mm256_broadcast_sd((double const *)&ones);
+                    ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+                    ymm4 = _mm256_unpacklo_pd(ymm0, ymm1);
+                    ymm5 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+                    ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+                    ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+
+                    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);
+                    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);
+
+                    ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+                    ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+                    _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7);
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8);
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9);
+
+                    a10 += p_lda;
+                    ptr_a10_dup += p_lda * p_lda;
+                }
+            }
+            else
+            {
+                for (dim_t x = 0; x < m - m_remainder; x++)
+                {
+                    ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a));
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0);
+                }
+            }
+            // cols
+            for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension
+            {
+                a10 = D_A_pack;
+                a11 = L;                            // pointer to block of A to be used for TRSM
+                b01 = B + (j * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+                b11 = B + (j * cs_b);               // pointer to block of B to be used for TRSM
+
+                k_iter = (m - m_remainder); // number of times GEMM is performed (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in GEMM accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter)
+
+                ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal));
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));
+                ymm1 = _mm256_insertf64x2(ymm1, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));
+                ymm2 = _mm256_insertf64x2(ymm2, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2));
+                ymm3 = _mm256_insertf64x2(ymm3, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3));
+                ymm4 = _mm256_insertf64x2(ymm4, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 4));
+                ymm5 = _mm256_insertf64x2(ymm5, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 5));
+                ymm6 = _mm256_insertf64x2(ymm6, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 6));
+                ymm7 = _mm256_insertf64x2(ymm7, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 7));
+                ymm8 = _mm256_insertf64x2(ymm8, xmm5, 0);
+
+                ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9);
+                ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10);
+                ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11);
+                ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12);
+                ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13);
+                ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14);
+                ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15);
+                ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16);
+
+                _mm_storeu_pd((double *)(b11 + cs_b * 0), _mm256_extractf64x2_pd(ymm9, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_extractf64x2_pd(ymm10, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_extractf64x2_pd(ymm11, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_extractf64x2_pd(ymm12, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 4), _mm256_extractf64x2_pd(ymm13, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 5), _mm256_extractf64x2_pd(ymm14, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 6), _mm256_extractf64x2_pd(ymm15, 0));
+                _mm_storeu_pd((double *)(b11 + cs_b * 7), _mm256_extractf64x2_pd(ymm16, 0));
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 8, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 8, rs_a, cs_b, is_unitdiag);
+            }
+
+            dim_t n_remainder = j + d_nr; // leftover columns after the j loop
+            if (n_remainder >= 4)
+            {
+                a10 = D_A_pack;
+                a11 = L;                                            // pointer to block of A to be used for TRSM
+                b01 = B + ((n_remainder - 4) * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+                b11 = B + ((n_remainder - 4) * cs_b);               // pointer to block of B to be used for TRSM
+
+                k_iter = (m - m_remainder); // number of times GEMM is performed (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in GEMM accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter)
+
+                ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+                /// implement TRSM///
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 0));
+                ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1));
+                ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2));
+                ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);
+
+                xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 3));
+                ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0);
+
+                ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+                ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+                ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+                ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+                _mm_storeu_pd((double *)(b11), _mm256_castpd256_pd128(ymm8));
+                _mm_storeu_pd((double *)(b11 + cs_b * 1), _mm256_castpd256_pd128(ymm9));
+                _mm_storeu_pd((double *)(b11 + cs_b * 2), _mm256_castpd256_pd128(ymm10));
+                _mm_storeu_pd((double *)(b11 + cs_b * 3), _mm256_castpd256_pd128(ymm11));
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag);
+                n_remainder -= 4;
+            }
+
+            if (n_remainder)
+            {
+                a10 = D_A_pack;
+                a11 = L;               // pointer to block of A to be used for TRSM
+                b01 = B + m_remainder; // pointer to block of B to be used for GEMM
+                b11 = B;               // pointer to block of B to be used for TRSM
+
+                k_iter = (m - m_remainder); // number of times GEMM is performed (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in GEMM accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                if (3 == n_remainder)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_2M_3N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag);
+                    else
+                        dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag);
+                }
+                else if (2 == n_remainder)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_2M_2N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+                    else
+                        dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag);
+                }
+                else if (1 == n_remainder)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_2M_1N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+                    else
+                        dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag);
+                }
+            }
+        }
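The `transa` packing path above is an in-register 4x4 transpose: the two real columns of A (the other two inputs are padded with `ones`) are interleaved with `_mm256_unpacklo_pd`/`_mm256_unpackhi_pd`, then the 128-bit halves are recombined with `_mm256_permute2f128_pd`. The same shuffle network with four real rows, as a standalone sketch (hypothetical helper, assuming AVX2 and <immintrin.h>; not part of the patch):

    #include <immintrin.h>

    /* Transpose a 4x4 block of doubles: row x of 'a' (row stride lda)
       becomes row x of 'b' holding column x of 'a' (row stride ldb). */
    static inline void transpose_4x4_pd(const double *a, int lda,
                                        double *b, int ldb)
    {
        __m256d r0 = _mm256_loadu_pd(a + 0 * lda); // a00 a01 a02 a03
        __m256d r1 = _mm256_loadu_pd(a + 1 * lda); // a10 a11 a12 a13
        __m256d r2 = _mm256_loadu_pd(a + 2 * lda); // a20 a21 a22 a23
        __m256d r3 = _mm256_loadu_pd(a + 3 * lda); // a30 a31 a32 a33

        __m256d t0 = _mm256_unpacklo_pd(r0, r1);   // a00 a10 a02 a12
        __m256d t1 = _mm256_unpackhi_pd(r0, r1);   // a01 a11 a03 a13
        __m256d t2 = _mm256_unpacklo_pd(r2, r3);   // a20 a30 a22 a32
        __m256d t3 = _mm256_unpackhi_pd(r2, r3);   // a21 a31 a23 a33

        // 0x20 selects the low 128-bit halves, 0x31 the high halves
        _mm256_storeu_pd(b + 0 * ldb, _mm256_permute2f128_pd(t0, t2, 0x20));
        _mm256_storeu_pd(b + 1 * ldb, _mm256_permute2f128_pd(t1, t3, 0x20));
        _mm256_storeu_pd(b + 2 * ldb, _mm256_permute2f128_pd(t0, t2, 0x31));
        _mm256_storeu_pd(b + 3 * ldb, _mm256_permute2f128_pd(t1, t3, 0x31));
    }

Padding the unused inputs with ones rather than leaving them undefined keeps the shuffle network identical for every remainder size, at the cost of packing two throwaway rows.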
+        else if (1 == m_remainder) // remaining A block is 1x1
+        {
+            dim_t p_lda = 4; // packed leading dimension
+            if (transa)
+            {
+                for (dim_t x = 0; x < m - m_remainder; x += p_lda)
+                {
+                    ymm0 = _mm256_loadu_pd((double const *)(a10));
+                    ymm1 = _mm256_broadcast_sd((double const *)&ones);
+                    ymm2 = _mm256_broadcast_sd((double const *)&ones);
+                    ymm3 = _mm256_broadcast_sd((double const *)&ones);
+
+                    ymm4 = _mm256_unpacklo_pd(ymm0, ymm1);
+                    ymm5 = _mm256_unpacklo_pd(ymm2, ymm3);
+
+                    ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20);
+                    ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31);
+
+                    ymm0 = _mm256_unpackhi_pd(ymm0, ymm1);
+                    ymm1 = _mm256_unpackhi_pd(ymm2, ymm3);
+
+                    ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20);
+                    ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31);
+
+                    _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7);
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 2), ymm8);
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda * 3), ymm9);
+
+                    a10 += p_lda;
+                    ptr_a10_dup += p_lda * p_lda;
+                }
+            }
+            else
+            {
+                for (dim_t x = 0; x < m - m_remainder; x++)
+                {
+                    ymm0 = _mm256_loadu_pd((double const *)(a10 + x * rs_a));
+                    _mm256_storeu_pd((double *)(ptr_a10_dup + x * p_lda), ymm0);
+                }
+            }
+            // cols
+            for (j = (n - d_nr); (j + 1) > 0; j -= d_nr) // loop along 'N' dimension
+            {
+                a10 = D_A_pack;
+                a11 = L;                            // pointer to block of A to be used for TRSM
+                b01 = B + (j * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+                b11 = B + (j * cs_b);               // pointer to block of B to be used for TRSM
+
+                k_iter = (m - m_remainder); // number of times GEMM is performed (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in GEMM accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx8n(a10, b01, cs_b, p_lda, k_iter)
+                /// GEMM code ends///
+
+                ymm0 = _mm256_broadcast_sd((double const *)(&AlphaVal));
+                ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0));
+                ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1));
+                ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2));
+                ymm4 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3));
+                ymm5 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 4));
+                ymm6 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 5));
+                ymm7 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 6));
+                ymm8 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 7));
+
+                ymm9 = _mm256_fmsub_pd(ymm1, ymm0, ymm9);
+                ymm10 = _mm256_fmsub_pd(ymm2, ymm0, ymm10);
+                ymm11 = _mm256_fmsub_pd(ymm3, ymm0, ymm11);
+                ymm12 = _mm256_fmsub_pd(ymm4, ymm0, ymm12);
+                ymm13 = _mm256_fmsub_pd(ymm5, ymm0, ymm13);
+                ymm14 = _mm256_fmsub_pd(ymm6, ymm0, ymm14);
+                ymm15 = _mm256_fmsub_pd(ymm7, ymm0, ymm15);
+                ymm16 = _mm256_fmsub_pd(ymm8, ymm0, ymm16);
+
+                _mm_storel_pd((double *)(b11 + cs_b * 0), _mm256_extractf64x2_pd(ymm9, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf64x2_pd(ymm10, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf64x2_pd(ymm11, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf64x2_pd(ymm12, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 4), _mm256_extractf64x2_pd(ymm13, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 5), _mm256_extractf64x2_pd(ymm14, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 6), _mm256_extractf64x2_pd(ymm15, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 7), _mm256_extractf64x2_pd(ymm16, 0));
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 8, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 8, rs_a, cs_b, is_unitdiag);
+            }
+
+            dim_t n_remainder = j + d_nr; // leftover columns after the j loop
+            if (n_remainder >= 4)
+            {
+                a10 = D_A_pack;
+                a11 = L;                                            // pointer to block of A to be used for TRSM
+                b01 = B + ((n_remainder - 4) * cs_b) + m_remainder; // pointer to block of B to be used for GEMM
+                b11 = B + ((n_remainder - 4) * cs_b);               // pointer to block of B to be used for TRSM
+
+                k_iter = (m - m_remainder); // number of times GEMM is performed (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in GEMM accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                /// GEMM code begins///
+                BLIS_DTRSM_SMALL_GEMM_4mx4n(a10, b01, cs_b, p_lda, k_iter)
+
+                ymm16 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha
+
+                /// implement TRSM///
+
+                ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 0));
+                ymm1 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1));
+                ymm2 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2));
+                ymm3 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 3));
+
+                ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm8);
+                ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm9);
+                ymm10 = _mm256_fmsub_pd(ymm2, ymm16, ymm10);
+                ymm11 = _mm256_fmsub_pd(ymm3, ymm16, ymm11);
+
+                _mm_storel_pd((double *)(b11), _mm256_extractf128_pd(ymm8, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 1), _mm256_extractf128_pd(ymm9, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 2), _mm256_extractf128_pd(ymm10, 0));
+                _mm_storel_pd((double *)(b11 + cs_b * 3), _mm256_extractf128_pd(ymm11, 0));
+
+                if (transa)
+                    dtrsm_AltXB_ref(a11, b11, m_remainder, 4, cs_a, cs_b, is_unitdiag);
+                else
+                    dtrsm_AuXB_ref(a11, b11, m_remainder, 4, rs_a, cs_b, is_unitdiag);
+                n_remainder -= 4;
+            }
+
+            if (n_remainder)
+            {
+                a10 = D_A_pack;
+                a11 = L;               // pointer to block of A to be used for TRSM
+                b01 = B + m_remainder; // pointer to block of B to be used for GEMM
+                b11 = B;               // pointer to block of B to be used for TRSM
+
+                k_iter = (m - m_remainder); // number of times GEMM is performed (in blocks of 4x4)
+
+                /* Fill zeros into ymm registers used in GEMM accumulations */
+                BLIS_SET_YMM_REG_ZEROS_FOR_LEFT
+
+                if (3 == n_remainder)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx3n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_1M_3N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AltXB_ref(a11, b11, m_remainder, 3, cs_a, cs_b, is_unitdiag);
+                    else
+                        dtrsm_AuXB_ref(a11, b11, m_remainder, 3, rs_a, cs_b, is_unitdiag);
+                }
+                else if (2 == n_remainder)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx2n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_1M_2N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AltXB_ref(a11, b11, m_remainder, 2, cs_a, cs_b, is_unitdiag);
+                    else
+                        dtrsm_AuXB_ref(a11, b11, m_remainder, 2, rs_a, cs_b, is_unitdiag);
+                }
+                else if (1 == n_remainder)
+                {
+                    /// GEMM code begins///
+                    BLIS_DTRSM_SMALL_GEMM_4mx1n(a10, b01, cs_b, p_lda, k_iter)
+
+                    BLIS_PRE_DTRSM_SMALL_1M_1N(AlphaVal, b11, cs_b)
+
+                    if (transa)
+                        dtrsm_AltXB_ref(a11, b11, m_remainder, 1, cs_a, cs_b, is_unitdiag);
+                    else
+                        dtrsm_AuXB_ref(a11, b11, m_remainder, 1, rs_a, cs_b, is_unitdiag);
+                }
+            }
+        }
+    }
+
+    if ((required_packing_A == 1) &&
+        bli_mem_is_alloc(&local_mem_buf_A_s))
+    {
+        bli_membrk_release(&rntm, &local_mem_buf_A_s);
+    }
+    return BLIS_SUCCESS;
 }
 #endif
\ No newline at end of file
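A note on the alpha/GEMM combining step used throughout this kernel: `_mm256_fmsub_pd(x, y, z)` computes `x*y - z` elementwise, so passing the loaded B11 values, the broadcast alpha, and the GEMM accumulator yields the TRSM pre-update `alpha*B11 - (partial products of the already-solved panel)` in a single FMA instruction. A minimal model of one such column update (hypothetical helper, assuming FMA support and <immintrin.h>; not part of the patch):

    #include <immintrin.h>

    /* Apply alpha*b_col - acc to four contiguous doubles of a B column,
       where acc holds the GEMM partial sums for those elements. */
    static inline void pre_update_4(double *b_col, double alpha, __m256d acc)
    {
        __m256d b = _mm256_loadu_pd(b_col);
        // fmsub(x, y, z) = x*y - z, so this is alpha*B11 - acc
        _mm256_storeu_pd(b_col, _mm256_fmsub_pd(b, _mm256_set1_pd(alpha), acc));
    }

Folding the alpha scaling into the same instruction that subtracts the GEMM accumulation avoids a separate multiply pass over B11 before the triangular solve.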