From 18569b42eee4235d73c6f31e4d37c5644f296f69 Mon Sep 17 00:00:00 2001 From: Shubham Date: Tue, 17 Jan 2023 22:59:35 +0530 Subject: [PATCH] Added DTRSM small RUNN/RLTN variant AVX512 kernels - 8x8 kernels are used for DTRSM SMALL - Matrix A(a10) is packed for GEMM operations. - Packed martix A will be re-used in all the col-block along N-dimension. - Diagonal elements of A matrix are packed(a11) for TRSM operations. - Implemented fringe cases with following block sizes 8x8, 8x4, 8x3, 8x2, 8x1 4x8, 4x4, 4x3, 4x2, 4x1 3x8, 3x4, 3x3, 3x2, 3x1 2x8, 2x4, 2x3, 2x2, 2x1 1x8, 1x4, 1x3, 1x2, 1x1 AMD-Internal: [CPUPL-2745] Change-Id: I6a174e7f88a4c2c5778052525879552a1e82f6ad --- CMakeLists.txt | 1 + frame/3/bli_l3_sup_ker_prot.h | 28 +- frame/compat/bla_trsm_amd.c | 70 +- kernels/zen4/3/CMakeLists.txt | 1 + kernels/zen4/3/bli_trsm_small_AVX512.c | 4333 ++++++++++++++++++++++++ kernels/zen4/bli_kernels_zen4.h | 7 + 6 files changed, 4430 insertions(+), 10 deletions(-) create mode 100644 kernels/zen4/3/bli_trsm_small_AVX512.c diff --git a/CMakeLists.txt b/CMakeLists.txt index d91ff7594..34b5b6a08 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -342,6 +342,7 @@ if(${TARGET_ARCH} STREQUAL zen4 OR set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_gemmtrsm_u_zen_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_dgemm_skx_asm_16x14.c PROPERTIES COMPILE_FLAGS /arch:AVX512) set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/skx/3/bli_sgemm_skx_asm_32x12_l2.c PROPERTIES COMPILE_FLAGS /arch:AVX512) + set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/kernels/zen4/3/bli_trsm_small_AVX512.c PROPERTIES COMPILE_FLAGS /arch:AVX512) endif() set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} /W0 ") diff --git a/frame/3/bli_l3_sup_ker_prot.h b/frame/3/bli_l3_sup_ker_prot.h index 899a47d3f..afa8f8ace 100644 --- a/frame/3/bli_l3_sup_ker_prot.h +++ b/frame/3/bli_l3_sup_ker_prot.h @@ -4,7 +4,7 @@ An object-based framework for developing high-performance BLAS-like libraries. - Copyright (C) 2019, Advanced Micro Devices, Inc. + Copyright (C) 2019-2023, Advanced Micro Devices, Inc. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -54,3 +54,29 @@ void PASTEMAC(ch,opname) \ cntx_t* restrict cntx \ ); + + +#define TRSMSUP_PROT( opname ) \ +\ +err_t PASTEMAC0(opname) \ + ( \ + side_t side, \ + obj_t* alpha, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx, \ + cntl_t* cntl, \ + bool is_parallel \ + ); + +#define TRSMSUP_KER_PROT( ch, opname ) \ +\ +BLIS_INLINE err_t PASTEMAC(ch,opname) \ + ( \ + obj_t* AlphaObj, \ + obj_t* a, \ + obj_t* b, \ + cntx_t* cntx, \ + cntl_t* cntl \ + ); + diff --git a/frame/compat/bla_trsm_amd.c b/frame/compat/bla_trsm_amd.c index 93b9168ab..b5f295854 100644 --- a/frame/compat/bla_trsm_amd.c +++ b/frame/compat/bla_trsm_amd.c @@ -5,7 +5,7 @@ libraries. Copyright (C) 2014, The University of Texas at Austin - Copyright (C) 2019 - 2022, Advanced Micro Devices, Inc. All rights reserved. + Copyright (C) 2019 - 2023, Advanced Micro Devices, Inc. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -953,14 +953,66 @@ void dtrsm_blis_impl (is_parallel && (m0+n0)<320)) { err_t status; - status = bli_trsm_small( - blis_side, - &alphao, - &ao, - &bo, - NULL, - NULL, - is_parallel); + + // Query the architecture ID + arch_t id = bli_arch_query_id(); +#if defined(BLIS_KERNELS_ZEN4) + bool uplo, transa; +#endif + switch(id) + { + case BLIS_ARCH_ZEN4: +#if defined(BLIS_KERNELS_ZEN4) + // check if variant is RUN[N/U] or RLT[N/U] + // this is a temporary fix, will be removed when all variants are added + + // for n < 200 avx2 kernels are performing better, but if + // n is a multiple of 8 then there will be no fringe case for avx512, + // in such cases avx512 kernels will perform better. + uplo = bli_obj_is_upper(&ao); + transa = bli_obj_has_trans(&ao); + if(( ((blis_side == BLIS_RIGHT) && (uplo == true) && (transa == false)) || + ((blis_side == BLIS_RIGHT) && (uplo == false) && (transa == true))) && + ((n0 > 400) && (m0 > 50))) + { + status = bli_trsm_small_AVX512( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL, + is_parallel); + } + else + { + status = bli_trsm_small( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL, + is_parallel); + } + break; +#endif + + case BLIS_ARCH_ZEN: + case BLIS_ARCH_ZEN2: + case BLIS_ARCH_ZEN3: + status = bli_trsm_small( + blis_side, + &alphao, + &ao, + &bo, + NULL, + NULL, + is_parallel); + break; + default: + status = BLIS_NOT_YET_IMPLEMENTED; + } if (status == BLIS_SUCCESS) { AOCL_DTL_TRACE_EXIT(AOCL_DTL_LEVEL_INFO); diff --git a/kernels/zen4/3/CMakeLists.txt b/kernels/zen4/3/CMakeLists.txt index e609f56b0..82724c763 100644 --- a/kernels/zen4/3/CMakeLists.txt +++ b/kernels/zen4/3/CMakeLists.txt @@ -5,4 +5,5 @@ target_sources("${PROJECT_NAME}" ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_l_zen_16x14.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_gemmtrsm_u_zen_16x14.c ${CMAKE_CURRENT_SOURCE_DIR}/bli_dgemm_zen4_asm_32x6.c + ${CMAKE_CURRENT_SOURCE_DIR}/bli_trsm_small_AVX512.c ) diff --git a/kernels/zen4/3/bli_trsm_small_AVX512.c b/kernels/zen4/3/bli_trsm_small_AVX512.c new file mode 100644 index 000000000..9912c8e4a --- /dev/null +++ b/kernels/zen4/3/bli_trsm_small_AVX512.c @@ -0,0 +1,4333 @@ +/* + BLIS + An object-based framework for developing high-performance BLAS-like + libraries. + Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved. + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are + met: + - Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + - Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + - Neither the name(s) of the copyright holder(s) nor the names of its + contributors may be used to endorse or promote products derived + from this software without specific prior written permission. + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS + "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT + LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR + A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT + HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT + LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, + DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY + THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ +#include "blis.h" + +#ifdef BLIS_ENABLE_SMALL_MATRIX_TRSM +#include "immintrin.h" +#define BLIS_ENABLE_PREFETCH_IN_TRSM_SMALL + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + #define DIAG_ELE_INV_OPS(a, b) (a / b) + #define DIAG_ELE_EVAL_OPS(a, b) (a * b) +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + #define DIAG_ELE_INV_OPS(a, b) (a * b) + #define DIAG_ELE_EVAL_OPS(a, b) (a / b) +#endif + +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + #define DTRSM_SMALL_DIV_OR_SCALE _mm256_div_pd + #define DTRSM_SMALL_DIV_OR_SCALE_AVX512 _mm512_div_pd +#endif + +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + #define DTRSM_SMALL_DIV_OR_SCALE _mm256_mul_pd + #define DTRSM_SMALL_DIV_OR_SCALE_AVX512 _mm512_mul_pd +#endif + +#define BLIS_SET_YMM_REG_ZEROS_AVX512 \ + ymm0 = _mm256_setzero_pd(); \ + ymm1 = _mm256_setzero_pd(); \ + ymm2 = _mm256_setzero_pd(); \ + ymm3 = _mm256_setzero_pd(); \ + ymm4 = _mm256_setzero_pd(); \ + ymm5 = _mm256_setzero_pd(); \ + ymm6 = _mm256_setzero_pd(); \ + ymm7 = _mm256_setzero_pd(); \ + ymm8 = _mm256_setzero_pd(); \ + ymm9 = _mm256_setzero_pd(); \ + ymm10 = _mm256_setzero_pd(); \ + ymm11 = _mm256_setzero_pd(); \ + ymm12 = _mm256_setzero_pd(); \ + ymm13 = _mm256_setzero_pd(); \ + ymm14 = _mm256_setzero_pd(); \ + ymm15 = _mm256_setzero_pd(); \ + ymm16 = _mm256_setzero_pd(); \ + ymm17 = _mm256_setzero_pd(); \ + ymm18 = _mm256_setzero_pd(); \ + ymm19 = _mm256_setzero_pd(); \ + ymm20 = _mm256_setzero_pd(); \ + ymm21 = _mm256_setzero_pd(); \ + ymm22 = _mm256_setzero_pd(); \ + ymm23 = _mm256_setzero_pd(); \ + ymm24 = _mm256_setzero_pd(); \ + ymm25 = _mm256_setzero_pd(); \ + ymm26 = _mm256_setzero_pd(); \ + ymm27 = _mm256_setzero_pd(); \ + ymm28 = _mm256_setzero_pd(); \ + ymm29 = _mm256_setzero_pd(); \ + ymm30 = _mm256_setzero_pd(); \ + ymm31 = _mm256_setzero_pd(); + +#define BLIS_SET_ZMM_REG_ZEROS \ + zmm0 = _mm512_setzero_pd(); \ + zmm1 = _mm512_setzero_pd(); \ + zmm2 = _mm512_setzero_pd(); \ + zmm3 = _mm512_setzero_pd(); \ + zmm4 = _mm512_setzero_pd(); \ + zmm5 = _mm512_setzero_pd(); \ + zmm6 = _mm512_setzero_pd(); \ + zmm7 = _mm512_setzero_pd(); \ + zmm8 = _mm512_setzero_pd(); \ + zmm9 = _mm512_setzero_pd(); \ + zmm10 = _mm512_setzero_pd(); \ + zmm11 = _mm512_setzero_pd(); \ + zmm12 = _mm512_setzero_pd(); \ + zmm13 = _mm512_setzero_pd(); \ + zmm14 = _mm512_setzero_pd(); \ + zmm15 = _mm512_setzero_pd(); \ + zmm16 = _mm512_setzero_pd(); \ + zmm17 = _mm512_setzero_pd(); \ + zmm18 = _mm512_setzero_pd(); \ + zmm19 = _mm512_setzero_pd(); \ + zmm20 = _mm512_setzero_pd(); \ + zmm21 = _mm512_setzero_pd(); \ + zmm22 = _mm512_setzero_pd(); \ + zmm23 = _mm512_setzero_pd(); \ + zmm24 = _mm512_setzero_pd(); \ + zmm25 = _mm512_setzero_pd(); \ + zmm26 = _mm512_setzero_pd(); \ + zmm27 = _mm512_setzero_pd(); \ + zmm28 = _mm512_setzero_pd(); \ + zmm29 = _mm512_setzero_pd(); \ + zmm30 = _mm512_setzero_pd(); \ + zmm31 = _mm512_setzero_pd(); + +/* + declaration of trsm small kernels function pointer +*/ +typedef err_t (*trsmsmall_ker_ft) + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ); + +/* + Pack a block of 8xk from input buffer into packed buffer + directly or after transpose based on input params +*/ +BLIS_INLINE void bli_dtrsm_small_pack_avx512 + ( + char side, + dim_t size, + bool trans, + double* inbuf, + dim_t cs_a, + double* pbuff, + dim_t p_lda, + dim_t mr + ) +{ + // scratch registers + __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7; + if (side == 'L' || side == 'l') + { + return; // BLIS_NOT_YET_IMPLEMENTED + } + else if (side == 'R' || side == 'r') + { + if (trans) + { + /* + ---------------- ------------- + | | | | | + | 4x8 | | | | + ------------- ==> | 8x4 | 8x4 | + | 4x8 | | | | + | | | | | + ---------------- ------------- + */ + __m256d ymm0, ymm1, ymm2, ymm3; + __m256d ymm4, ymm5, ymm6, ymm7; + __m256d ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13; + for (dim_t x = 0; x < p_lda; x += mr) + { + // load 4x8 + ymm0 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 0))); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 1))); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 2))); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 3))); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + 4 + (cs_a * 0))); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + 4 + (cs_a * 1))); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + 4 + (cs_a * 2))); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + 4 + (cs_a * 3))); + + // transpose 4x4 + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + // store 4x4 + _mm256_storeu_pd((double *)(pbuff + (p_lda * 0)), ymm6); + _mm256_storeu_pd((double *)(pbuff + (p_lda * 1)), ymm7); + _mm256_storeu_pd((double *)(pbuff + (p_lda * 2)), ymm8); + _mm256_storeu_pd((double *)(pbuff + (p_lda * 3)), ymm9); + + // transpose 4x4 + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + // store 4x4 + _mm256_storeu_pd((double *)(pbuff + (p_lda * 4)), ymm6); + _mm256_storeu_pd((double *)(pbuff + (p_lda * 5)), ymm7); + _mm256_storeu_pd((double *)(pbuff + (p_lda * 6)), ymm8); + _mm256_storeu_pd((double *)(pbuff + (p_lda * 7)), ymm9); + + // load 4x8 + ymm0 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 4))); + ymm1 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 5))); + ymm2 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 6))); + ymm3 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 7))); + ymm10 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 4) + 4)); + ymm11 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 5) + 4)); + ymm12 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 6) + 4)); + ymm13 = _mm256_loadu_pd((double const *)(inbuf + (cs_a * 7) + 4)); + + // transpose 4x4 + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + // store 4x4 + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 0)), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 1)), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 2)), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 3)), ymm9); + + // transpose 4x4 + ymm4 = _mm256_unpacklo_pd(ymm10, ymm11); + ymm5 = _mm256_unpacklo_pd(ymm12, ymm13); + ymm6 = _mm256_permute2f128_pd(ymm4, ymm5, 0x20); + ymm8 = _mm256_permute2f128_pd(ymm4, ymm5, 0x31); + ymm0 = _mm256_unpackhi_pd(ymm10, ymm11); + ymm1 = _mm256_unpackhi_pd(ymm12, ymm13); + ymm7 = _mm256_permute2f128_pd(ymm0, ymm1, 0x20); + ymm9 = _mm256_permute2f128_pd(ymm0, ymm1, 0x31); + + // store 4x4 + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 4)), ymm6); + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 5)), ymm7); + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 6)), ymm8); + _mm256_storeu_pd((double *)(pbuff + 4 + (p_lda * 7)), ymm9); + + inbuf += mr * cs_a; + pbuff += mr; + } + } + else + { + dim_t size_div_8 = size >> 3; + for (int i = 0; i < size_div_8; i++) + { + zmm0 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 0))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 0)), zmm0); + zmm1 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 1))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 1)), zmm1); + zmm2 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 2))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 2)), zmm2); + zmm3 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 3))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 3)), zmm3); + zmm4 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 4))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 4)), zmm4); + zmm5 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 5))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 5)), zmm5); + zmm6 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 6))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 6)), zmm6); + zmm7 = _mm512_loadu_pd((double const *)(inbuf + (cs_a * 7))); + _mm512_storeu_pd((double *)(pbuff + (p_lda * 7)), zmm7); + inbuf += 8; + pbuff += 8; + } + } + } +} +/* + Pack diagonal elements of A block (8) into an array + a. This helps in utilze cache line efficiently in TRSM operation + b. store ones when input is unit diagonal +*/ +BLIS_INLINE void dtrsm_small_pack_diag_element_avx512 + ( + bool is_unitdiag, + double* a11, + dim_t cs_a, + double* d11_pack, + dim_t size + ) +{ + __m512d zmm0, zmm1, zmm2, zmm3; + __m512d zmm4, zmm5, zmm6, zmm7; + __m512d zmm8; + double ones = 1.0; + // if (size == 8) + { + zmm8 = _mm512_set1_pd(ones); + if (!is_unitdiag) + { + __m512d zmm10, zmm11, zmm12, zmm13; + __m512d zmm14, zmm15; + // broadcast diagonal elements of A11 + zmm0 = _mm512_set1_pd(*(a11 + (cs_a * 0) + 0)); + zmm1 = _mm512_set1_pd(*(a11 + (cs_a * 1) + 1)); + zmm2 = _mm512_set1_pd(*(a11 + (cs_a * 2) + 2)); + zmm3 = _mm512_set1_pd(*(a11 + (cs_a * 3) + 3)); + zmm4 = _mm512_set1_pd(*(a11 + (cs_a * 4) + 4)); + zmm5 = _mm512_set1_pd(*(a11 + (cs_a * 5) + 5)); + zmm6 = _mm512_set1_pd(*(a11 + (cs_a * 6) + 6)); + zmm7 = _mm512_set1_pd(*(a11 + (cs_a * 7) + 7)); + + //combine all elements of A11 into zmm1 + // Stage 1 + zmm10 = _mm512_unpacklo_pd(zmm0, zmm1); + zmm11 = _mm512_unpacklo_pd(zmm2, zmm3); + zmm12 = _mm512_unpacklo_pd(zmm4, zmm5); + zmm13 = _mm512_unpacklo_pd(zmm6, zmm7); + // Stage 2 + zmm14 = _mm512_shuffle_f64x2(zmm10, zmm11, 0b10001000); + zmm15 = _mm512_shuffle_f64x2(zmm12, zmm13, 0b10001000); + // Stage 3 + zmm1 = _mm512_shuffle_f64x2(zmm14, zmm15, 0b10001000); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + zmm8 = zmm1; + #endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + zmm8 = _mm512_div_pd(zmm8, zmm1); + #endif + } + _mm512_storeu_pd((double *)(d11_pack), zmm8); + } +} +/* + * Kernels Table + */ +trsmsmall_ker_ft ker_fps_AVX512[4][8] = + { + {NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL}, + {NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL}, + {bli_dtrsm_small_AutXB_AlXB_AVX512, + bli_dtrsm_small_AltXB_AuXB_AVX512, + bli_dtrsm_small_AltXB_AuXB_AVX512, + bli_dtrsm_small_AutXB_AlXB_AVX512, + bli_dtrsm_small_XAutB_XAlB_AVX512, + bli_dtrsm_small_XAltB_XAuB_AVX512, + bli_dtrsm_small_XAltB_XAuB_AVX512, + bli_dtrsm_small_XAutB_XAlB_AVX512}, + {NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL, + NULL}, +}; +/* +* The bli_trsm_small implements a version of TRSM where A is packed and reused +* +* Input: A: MxM (triangular matrix) +* B: MxN matrix +* Output: X: MxN matrix such that + AX = alpha*B or XA = alpha*B or A'X = alpha*B or XA' = alpha*B +* Here the output X is stored in B +* +* Note: Currently only dtrsm is supported when A & B are column-major +*/ +err_t bli_trsm_small_AVX512 + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl, + bool is_parallel + ) +{ + err_t err; + dim_t m = bli_obj_length(b); + dim_t n = bli_obj_width(b); + + bool uplo = bli_obj_is_upper(a); + bool transa = bli_obj_has_trans(a); + num_t dt = bli_obj_dt(a); + + switch (dt) + { + case BLIS_DOUBLE: + { + if ((!is_parallel) && (m > 1200 || n > 1200)) + { + return BLIS_NOT_YET_IMPLEMENTED; + } + break; + } + case BLIS_FLOAT: + case BLIS_SCOMPLEX: + case BLIS_DCOMPLEX: + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } + /* If alpha is zero, B matrix will become zero after scaling + hence solution is also zero matrix */ + if (bli_obj_equals(alpha, &BLIS_ZERO)) + { + return BLIS_NOT_YET_IMPLEMENTED; // scale B by alpha + } + + // Return if inputs are row major as currently + // we are supporing col major only + if ((bli_obj_row_stride(a) != 1) || + (bli_obj_row_stride(b) != 1)) + { + return BLIS_INVALID_ROW_STRIDE; + } + + // A is expected to be triangular in trsm + if (!bli_obj_is_upper_or_lower(a)) + { + return BLIS_EXPECTED_TRIANGULAR_OBJECT; + } + /* + * Compose kernel index based on inputs + */ + dim_t keridx = (((side & 0x1) << 2) | + ((uplo & 0x1) << 1) | + (transa & 0x1)); + trsmsmall_ker_ft ker_fp = ker_fps_AVX512[dt][keridx]; + /*Call the kernel*/ + err = ker_fp( + alpha, + a, + b, + cntx, + cntl); + return err; +}; + +#ifdef BLIS_ENABLE_OPENMP +/* + * Parallelized dtrsm_small across m-dimension or n-dimension based on side(Left/Right) + */ +err_t bli_trsm_small_mt_AVX512 + ( + side_t side, + obj_t* alpha, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + gint_t m = bli_obj_length(b); // number of rows of matrix b + gint_t n = bli_obj_width(b); // number of columns of Matrix b + dim_t d_mr = 8,d_nr = 8; + + num_t dt = bli_obj_dt(a); + switch (dt) + { + case BLIS_DOUBLE: + { + d_mr = 8, d_nr = 8; + break; + } + default: + { + return BLIS_NOT_YET_IMPLEMENTED; + break; + } + } + + rntm_t rntm; + bli_rntm_init_from_global(&rntm); + +#ifdef AOCL_DYNAMIC + // If dynamic-threading is enabled, calculate optimum number + // of threads. + // rntm will be updated with optimum number of threads. + if (bli_obj_is_double(b)) + { + bli_nthreads_optimum(a, b, b, BLIS_TRSM, &rntm); + } +#endif + + // Query the total number of threads from the rntm_t object. + dim_t n_threads = bli_rntm_num_threads(&rntm); + + if (n_threads < 0) + n_threads = 1; + + bool is_parallel = bli_thread_get_is_parallel(); + + err_t status = BLIS_SUCCESS; + _Pragma("omp parallel num_threads(n_threads)") + { + // Query the thread's id from OpenMP. + const dim_t tid = omp_get_thread_num(); + const dim_t nt_real = omp_get_num_threads(); + + // if num threads requested and num thread available + // is not same then use single thread small + if(nt_real != n_threads) + { + if(tid == 0) + { + bli_trsm_small + ( + side, + alpha, + a, + b, + cntx, + cntl, + is_parallel + ); + } + } + else + { + obj_t b_t; + dim_t start; // Each thread start Index + dim_t end; // Each thread end Index + thrinfo_t thread; + + thread.n_way = n_threads; + thread.work_id = tid; + thread.ocomm_id = tid; + + // Compute start and end indexes of matrix partitioning for each thread + if (bli_is_right(side)) + { + bli_thread_range_sub + ( + &thread, + m, + d_mr, // Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_mdim(BLIS_FWD, BLIS_SUBPART1, start, end - start, b, &b_t); + } + else + { + bli_thread_range_sub + ( + &thread, + n, + d_nr,// Need to decide based on type + FALSE, + &start, + &end + ); + // For each thread acquire matrix block on which they operate + // Data-based parallelism + + bli_acquire_mpart_ndim(BLIS_FWD, BLIS_SUBPART1, start, end - start, b, &b_t); + } + + // Parallelism is only across m-dimension/n-dimension - therefore matrix a is common to + // all threads + err_t status_l = BLIS_SUCCESS; + + status_l = bli_trsm_small_AVX512 + ( + side, + alpha, + a, + &b_t, + NULL, + NULL, + is_parallel + ); + // To capture the error populated from any of the threads + _Pragma("omp critical") + status = (status != BLIS_NOT_YET_IMPLEMENTED) ? status_l : status; + } + } + + return status; +} // End of function +#endif + +// region - GEMM DTRSM for right variants + +#define BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ + /*K loop is broken into two seperate loops + each loop computes k/2 iterations */ \ + \ + int itr = (k_iter / 2); /*itr count for first loop*/\ + int itr2 = k_iter - itr; /*itr count for second loop*/\ + double *a01_2 = a01 + itr; /*a01 for second loop*/\ + double *b10_2 = b10 + (cs_b * itr); /*b10 for second loop*/\ + for (; itr > 0; itr--) \ + { \ + zmm0 = _mm512_loadu_pd((double const *)b10); \ + \ + zmm1 = _mm512_set1_pd(*(a01 + (p_lda * 0))); \ + zmm2 = _mm512_set1_pd(*(a01 + (p_lda * 1))); \ + zmm3 = _mm512_set1_pd(*(a01 + (p_lda * 2))); \ + zmm4 = _mm512_set1_pd(*(a01 + (p_lda * 3))); \ + zmm5 = _mm512_set1_pd(*(a01 + (p_lda * 4))); \ + zmm6 = _mm512_set1_pd(*(a01 + (p_lda * 5))); \ + zmm7 = _mm512_set1_pd(*(a01 + (p_lda * 6))); \ + zmm8 = _mm512_set1_pd(*(a01 + (p_lda * 7))); \ + \ + /*prefetch b10 4 iterations in advance*/ \ + _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9 ); \ + zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ + zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ + zmm12 = _mm512_fmadd_pd(zmm4, zmm0, zmm12); \ + zmm13 = _mm512_fmadd_pd(zmm5, zmm0, zmm13); \ + zmm14 = _mm512_fmadd_pd(zmm6, zmm0, zmm14); \ + zmm15 = _mm512_fmadd_pd(zmm7, zmm0, zmm15); \ + zmm16 = _mm512_fmadd_pd(zmm8, zmm0, zmm16); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } \ + for (; itr2 > 0; itr2--) \ + { \ + zmm23 = _mm512_loadu_pd((double const *)b10_2); \ + \ + zmm17 = _mm512_set1_pd(*(a01_2 + (p_lda * 0))); \ + zmm18 = _mm512_set1_pd(*(a01_2 + (p_lda * 1))); \ + zmm19 = _mm512_set1_pd(*(a01_2 + (p_lda * 2))); \ + zmm20 = _mm512_set1_pd(*(a01_2 + (p_lda * 3))); \ + zmm21 = _mm512_set1_pd(*(a01_2 + (p_lda * 4))); \ + zmm22 = _mm512_set1_pd(*(a01_2 + (p_lda * 5))); \ + \ + _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \ + zmm17 = _mm512_set1_pd(*(a01_2 + (p_lda * 6))); \ + zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \ + zmm18 = _mm512_set1_pd(*(a01_2 + (p_lda * 7))); \ + zmm26 = _mm512_fmadd_pd(zmm19, zmm23, zmm26); \ + zmm27 = _mm512_fmadd_pd(zmm20, zmm23, zmm27); \ + zmm28 = _mm512_fmadd_pd(zmm21, zmm23, zmm28); \ + zmm29 = _mm512_fmadd_pd(zmm22, zmm23, zmm29); \ + zmm30 = _mm512_fmadd_pd(zmm17, zmm23, zmm30); \ + zmm31 = _mm512_fmadd_pd(zmm18, zmm23, zmm31); \ + \ + a01_2 += 1; \ + b10_2 += cs_b; \ + } \ + \ + /*prefetch 8 columns of b11)*/ \ + _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + /*combine the results of both loops*/ \ + zmm9 = _mm512_add_pd(zmm9, zmm24); \ + _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + zmm10 = _mm512_add_pd(zmm10, zmm25); \ + _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + zmm11 = _mm512_add_pd(zmm11, zmm26); \ + _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + zmm12 = _mm512_add_pd(zmm12, zmm27); \ + _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + zmm13 = _mm512_add_pd(zmm13, zmm28); \ + _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + zmm14 = _mm512_add_pd(zmm14, zmm29); \ + _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + zmm15 = _mm512_add_pd(zmm15, zmm30); \ + _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + zmm16 = _mm512_add_pd(zmm16, zmm31); +/* +// alternative way to prrefetch b11 +// itr2 = itr2 + itr + 8; \ +// for(;itr2>0;itr2--) \ +// {\ +// zmm23 = _mm512_loadu_pd((double const *)b10_2); \ +// \ +// zmm17 = _mm512_set1_pd(*(a01_2 + p_lda * 0)); \ +// zmm18 = _mm512_set1_pd(*(a01_2 + p_lda * 1)); \ +// zmm19 = _mm512_set1_pd(*(a01_2 + p_lda * 2)); \ +// zmm20 = _mm512_set1_pd(*(a01_2 + p_lda * 3)); \ +// zmm21 = _mm512_set1_pd(*(a01_2 + p_lda * 4)); \ +// zmm22 = _mm512_set1_pd(*(a01_2 + p_lda * 5)); \ +// \ +// _mm_prefetch((b10_2 + 4*cs_b), _MM_HINT_T0); \ +// _mm_prefetch((b11 + (itr2-1)*cs_b), _MM_HINT_T0); \ +// zmm24 = _mm512_fmadd_pd(zmm17, zmm23, zmm24); \ +// zmm17 = _mm512_set1_pd(*(a01_2 + p_lda * 6)); \ +// zmm25 = _mm512_fmadd_pd(zmm18, zmm23, zmm25); \ +// zmm18 = _mm512_set1_pd(*(a01_2 + p_lda * 7)); \ +// zmm26 = _mm512_fmadd_pd(zmm19, zmm23, zmm26); \ +// zmm27 = _mm512_fmadd_pd(zmm20, zmm23, zmm27); \ +// zmm28 = _mm512_fmadd_pd(zmm21, zmm23, zmm28); \ +// zmm29 = _mm512_fmadd_pd(zmm22, zmm23, zmm29); \ +// zmm30 = _mm512_fmadd_pd(zmm17, zmm23, zmm30); \ +// zmm31 = _mm512_fmadd_pd(zmm18, zmm23, zmm31); \ +// \ +// a01_2 += 1;\ +// b10_2 += cs_b; \ +// }\ +*/ +/* +// alternative version of main loop +#define BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ + int itr = k_iter - 8; \ + for(;itr>0;itr--) \ + {\ + zmm0 = _mm512_loadu_pd((double const *)b10); \ + \ + zmm1 = _mm512_set1_pd(*(a01 + p_lda * 0)); \ + zmm2 = _mm512_set1_pd(*(a01 + p_lda * 1)); \ + zmm3 = _mm512_set1_pd(*(a01 + p_lda * 2)); \ + zmm4 = _mm512_set1_pd(*(a01 + p_lda * 3)); \ + zmm5 = _mm512_set1_pd(*(a01 + p_lda * 4)); \ + zmm6 = _mm512_set1_pd(*(a01 + p_lda * 5)); \ + zmm7 = _mm512_set1_pd(*(a01 + p_lda * 6)); \ + zmm8 = _mm512_set1_pd(*(a01 + p_lda * 7)); \ + \ + _mm_prefetch((b10 + 4*cs_b), _MM_HINT_T0); \ + zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9 ); \ + zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ + zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ + zmm12 = _mm512_fmadd_pd(zmm4, zmm0, zmm12); \ + zmm13 = _mm512_fmadd_pd(zmm5, zmm0, zmm13); \ + zmm14 = _mm512_fmadd_pd(zmm6, zmm0, zmm14); \ + zmm15 = _mm512_fmadd_pd(zmm7, zmm0, zmm15); \ + zmm16 = _mm512_fmadd_pd(zmm8, zmm0, zmm16); \ + \ + a01 += 1;\ + b10 += cs_b; \ + }\ + itr += 8; \ + for(;itr>0;itr--) \ + {\ + zmm0 = _mm512_loadu_pd((double const *)b10); \ + \ + zmm1 = _mm512_set1_pd(*(a01 + p_lda * 0)); \ + zmm2 = _mm512_set1_pd(*(a01 + p_lda * 1)); \ + zmm3 = _mm512_set1_pd(*(a01 + p_lda * 2)); \ + zmm4 = _mm512_set1_pd(*(a01 + p_lda * 3)); \ + zmm5 = _mm512_set1_pd(*(a01 + p_lda * 4)); \ + zmm6 = _mm512_set1_pd(*(a01 + p_lda * 5)); \ + zmm7 = _mm512_set1_pd(*(a01 + p_lda * 6)); \ + zmm8 = _mm512_set1_pd(*(a01 + p_lda * 7)); \ + \ + _mm_prefetch((b10 + 4*cs_b), _MM_HINT_T0); \ + _mm_prefetch((b11 + (itr-1)*cs_b), _MM_HINT_T0); \ + zmm9 = _mm512_fmadd_pd(zmm1, zmm0, zmm9 ); \ + zmm10 = _mm512_fmadd_pd(zmm2, zmm0, zmm10); \ + zmm11 = _mm512_fmadd_pd(zmm3, zmm0, zmm11); \ + zmm12 = _mm512_fmadd_pd(zmm4, zmm0, zmm12); \ + zmm13 = _mm512_fmadd_pd(zmm5, zmm0, zmm13); \ + zmm14 = _mm512_fmadd_pd(zmm6, zmm0, zmm14); \ + zmm15 = _mm512_fmadd_pd(zmm7, zmm0, zmm15); \ + zmm16 = _mm512_fmadd_pd(zmm8, zmm0, zmm16); \ + \ + a01 += 1;\ + b10 += cs_b; \ + }\ +*/ + +#define BLIS_DTRSM_SMALL_GEMM_8nx4m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ + /*K loop is broken into two seperate loops + each loop computes k/2 iterations */ \ + \ + int itr = (k_iter / 2); /*itr count for first loop*/\ + int itr2 = k_iter - itr; /*itr count for second loop*/\ + double *a01_2 = a01 + itr; /*a01 for second loop*/\ + double *b10_2 = b10 + (cs_b * itr); /*b10 for second loop*/\ + for (; itr > 0; itr--) \ + { \ + ymm0 = _mm256_loadu_pd((double const *)(b10)); \ + \ + ymm1 = _mm256_broadcast_sd((a01 + (p_lda * 0))); \ + ymm2 = _mm256_broadcast_sd((a01 + (p_lda * 1))); \ + ymm3 = _mm256_broadcast_sd((a01 + (p_lda * 2))); \ + ymm4 = _mm256_broadcast_sd((a01 + (p_lda * 3))); \ + ymm5 = _mm256_broadcast_sd((a01 + (p_lda * 4))); \ + ymm6 = _mm256_broadcast_sd((a01 + (p_lda * 5))); \ + ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ + ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ + \ + _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ + ymm12 = _mm256_fmadd_pd(ymm4, ymm0, ymm12); \ + ymm13 = _mm256_fmadd_pd(ymm5, ymm0, ymm13); \ + ymm14 = _mm256_fmadd_pd(ymm6, ymm0, ymm14); \ + ymm15 = _mm256_fmadd_pd(ymm7, ymm0, ymm15); \ + ymm16 = _mm256_fmadd_pd(ymm8, ymm0, ymm16); \ + \ + a01 += 1; \ + b10 += cs_b; \ + } \ + for (; itr2 > 0; itr2--) \ + { \ + ymm23 = _mm256_loadu_pd((double const *)(b10_2)); \ + \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 0))); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 1))); \ + ymm19 = _mm256_broadcast_sd((a01_2 + (p_lda * 2))); \ + ymm20 = _mm256_broadcast_sd((a01_2 + (p_lda * 3))); \ + ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ + ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ + \ + _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ + ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 7))); \ + ymm26 = _mm256_fmadd_pd(ymm19, ymm23, ymm26); \ + ymm27 = _mm256_fmadd_pd(ymm20, ymm23, ymm27); \ + ymm28 = _mm256_fmadd_pd(ymm21, ymm23, ymm28); \ + ymm29 = _mm256_fmadd_pd(ymm22, ymm23, ymm29); \ + ymm30 = _mm256_fmadd_pd(ymm17, ymm23, ymm30); \ + ymm31 = _mm256_fmadd_pd(ymm18, ymm23, ymm31); \ + \ + a01_2 += 1; \ + b10_2 += cs_b; \ + } \ + /*combine the results of both loops*/ \ + _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_add_pd(ymm9, ymm24); \ + _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + ymm10 = _mm256_add_pd(ymm10, ymm25); \ + _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + ymm11 = _mm256_add_pd(ymm11, ymm26); \ + _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + ymm12 = _mm256_add_pd(ymm12, ymm27); \ + _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + ymm13 = _mm256_add_pd(ymm13, ymm28); \ + _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + ymm14 = _mm256_add_pd(ymm14, ymm29); \ + _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + ymm15 = _mm256_add_pd(ymm15, ymm30); \ + _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + ymm16 = _mm256_add_pd(ymm16, ymm31); + + +#define BLIS_DTRSM_SMALL_GEMM_8nx3m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ + /*K loop is broken into two seperate loops + each loop computes k/2 iterations */ \ + \ + int itr = (k_iter / 2); /*itr count for first loop*/\ + int itr2 = k_iter - itr; /*itr count for second loop*/\ + double *a01_2 = a01 + itr; /*a01 for second loop*/\ + double *b10_2 = b10 + (cs_b * itr); /*b10 for second loop*/\ + for (; itr > 0; itr--) \ + { \ + xmm5 = _mm_loadu_pd((b10)); /*load b10[0] and b10[1] into xmm5*/\ + ymm0 = _mm256_broadcast_sd((b10 + 2)); /*broadcast b10[2] into ymm0*/\ + ymm0 = _mm256_insertf64x2(ymm0, xmm5, 0); \ + /*ymm0 = {b10[0], b10[1], b10[2], b10[2]}*/\ + \ + ymm1 = _mm256_broadcast_sd((a01 + (p_lda * 0))); \ + ymm2 = _mm256_broadcast_sd((a01 + (p_lda * 1))); \ + ymm3 = _mm256_broadcast_sd((a01 + (p_lda * 2))); \ + ymm4 = _mm256_broadcast_sd((a01 + (p_lda * 3))); \ + ymm5 = _mm256_broadcast_sd((a01 + (p_lda * 4))); \ + ymm6 = _mm256_broadcast_sd((a01 + (p_lda * 5))); \ + ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ + ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ + \ + _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ + ymm12 = _mm256_fmadd_pd(ymm4, ymm0, ymm12); \ + ymm13 = _mm256_fmadd_pd(ymm5, ymm0, ymm13); \ + ymm14 = _mm256_fmadd_pd(ymm6, ymm0, ymm14); \ + ymm15 = _mm256_fmadd_pd(ymm7, ymm0, ymm15); \ + ymm16 = _mm256_fmadd_pd(ymm8, ymm0, ymm16); \ + \ + a01 += 1; \ + b10 += cs_b; \ + } \ + for (; itr2 > 0; itr2--) \ + { \ + xmm0 = _mm_loadu_pd((b10_2)); \ + ymm23 = _mm256_broadcast_sd((b10_2 + 2)); \ + ymm23 = _mm256_insertf64x2(ymm23, xmm0, 0); \ + \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 0))); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 1))); \ + ymm19 = _mm256_broadcast_sd((a01_2 + (p_lda * 2))); \ + ymm20 = _mm256_broadcast_sd((a01_2 + (p_lda * 3))); \ + ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ + ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ + \ + _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ + ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 7))); \ + ymm26 = _mm256_fmadd_pd(ymm19, ymm23, ymm26); \ + ymm27 = _mm256_fmadd_pd(ymm20, ymm23, ymm27); \ + ymm28 = _mm256_fmadd_pd(ymm21, ymm23, ymm28); \ + ymm29 = _mm256_fmadd_pd(ymm22, ymm23, ymm29); \ + ymm30 = _mm256_fmadd_pd(ymm17, ymm23, ymm30); \ + ymm31 = _mm256_fmadd_pd(ymm18, ymm23, ymm31); \ + \ + a01_2 += 1; \ + b10_2 += cs_b; \ + } \ + /*combine the results of both loops*/ \ + _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_add_pd(ymm9, ymm24); \ + _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + ymm10 = _mm256_add_pd(ymm10, ymm25); \ + _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + ymm11 = _mm256_add_pd(ymm11, ymm26); \ + _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + ymm12 = _mm256_add_pd(ymm12, ymm27); \ + _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + ymm13 = _mm256_add_pd(ymm13, ymm28); \ + _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + ymm14 = _mm256_add_pd(ymm14, ymm29); \ + _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + ymm15 = _mm256_add_pd(ymm15, ymm30); \ + _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + ymm16 = _mm256_add_pd(ymm16, ymm31); + + #define BLIS_DTRSM_SMALL_GEMM_8nx2m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ + /*K loop is broken into two seperate loops + each loop computes k/2 iterations */ \ + \ + int itr = (k_iter / 2); /*itr count for first loop*/\ + int itr2 = k_iter - itr; /*itr count for second loop*/\ + double *a01_2 = a01 + itr; /*a01 for second loop*/\ + double *b10_2 = b10 + (cs_b * itr); /*b10 for second loop*/\ + for (; itr > 0; itr--) \ + { \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_insertf64x2(ymm0, xmm5, 0); \ + \ + ymm1 = _mm256_broadcast_sd((a01 + (p_lda * 0))); \ + ymm2 = _mm256_broadcast_sd((a01 + (p_lda * 1))); \ + ymm3 = _mm256_broadcast_sd((a01 + (p_lda * 2))); \ + ymm4 = _mm256_broadcast_sd((a01 + (p_lda * 3))); \ + ymm5 = _mm256_broadcast_sd((a01 + (p_lda * 4))); \ + ymm6 = _mm256_broadcast_sd((a01 + (p_lda * 5))); \ + ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ + ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ + \ + _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ + ymm12 = _mm256_fmadd_pd(ymm4, ymm0, ymm12); \ + ymm13 = _mm256_fmadd_pd(ymm5, ymm0, ymm13); \ + ymm14 = _mm256_fmadd_pd(ymm6, ymm0, ymm14); \ + ymm15 = _mm256_fmadd_pd(ymm7, ymm0, ymm15); \ + ymm16 = _mm256_fmadd_pd(ymm8, ymm0, ymm16); \ + \ + a01 += 1; \ + b10 += cs_b; \ + } \ + for (; itr2 > 0; itr2--) \ + { \ + xmm0 = _mm_loadu_pd((double const *)(b10_2)); \ + ymm23 = _mm256_insertf64x2(ymm23, xmm0, 0); \ + \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 0))); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 1))); \ + ymm19 = _mm256_broadcast_sd((a01_2 + (p_lda * 2))); \ + ymm20 = _mm256_broadcast_sd((a01_2 + (p_lda * 3))); \ + ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ + ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ + \ + _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ + ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 7))); \ + ymm26 = _mm256_fmadd_pd(ymm19, ymm23, ymm26); \ + ymm27 = _mm256_fmadd_pd(ymm20, ymm23, ymm27); \ + ymm28 = _mm256_fmadd_pd(ymm21, ymm23, ymm28); \ + ymm29 = _mm256_fmadd_pd(ymm22, ymm23, ymm29); \ + ymm30 = _mm256_fmadd_pd(ymm17, ymm23, ymm30); \ + ymm31 = _mm256_fmadd_pd(ymm18, ymm23, ymm31); \ + \ + a01_2 += 1; \ + b10_2 += cs_b; \ + } \ + /*combine the results of both loops*/ \ + _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_add_pd(ymm9, ymm24); \ + _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + ymm10 = _mm256_add_pd(ymm10, ymm25); \ + _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + ymm11 = _mm256_add_pd(ymm11, ymm26); \ + _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + ymm12 = _mm256_add_pd(ymm12, ymm27); \ + _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + ymm13 = _mm256_add_pd(ymm13, ymm28); \ + _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + ymm14 = _mm256_add_pd(ymm14, ymm29); \ + _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + ymm15 = _mm256_add_pd(ymm15, ymm30); \ + _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + ymm16 = _mm256_add_pd(ymm16, ymm31); + +#define BLIS_DTRSM_SMALL_GEMM_8nx1m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) \ + /*K loop is broken into two seperate loops + each loop computes k/2 iterations */ \ + \ + int itr = (k_iter / 2); /*itr count for first loop*/\ + int itr2 = k_iter - itr; /*itr count for second loop*/\ + double *a01_2 = a01 + itr; /*a01 for second loop*/\ + double *b10_2 = b10 + (cs_b * itr); /*b10 for second loop*/\ + for (; itr > 0; itr--) \ + { \ + ymm0 = _mm256_broadcast_sd(b10); \ + \ + ymm1 = _mm256_broadcast_sd((a01 + (p_lda * 0))); \ + ymm2 = _mm256_broadcast_sd((a01 + (p_lda * 1))); \ + ymm3 = _mm256_broadcast_sd((a01 + (p_lda * 2))); \ + ymm4 = _mm256_broadcast_sd((a01 + (p_lda * 3))); \ + ymm5 = _mm256_broadcast_sd((a01 + (p_lda * 4))); \ + ymm6 = _mm256_broadcast_sd((a01 + (p_lda * 5))); \ + ymm7 = _mm256_broadcast_sd((a01 + (p_lda * 6))); \ + ymm8 = _mm256_broadcast_sd((a01 + (p_lda * 7))); \ + \ + _mm_prefetch((b10 + 4 * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_fmadd_pd(ymm1, ymm0, ymm9 ); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm0, ymm10); \ + ymm11 = _mm256_fmadd_pd(ymm3, ymm0, ymm11); \ + ymm12 = _mm256_fmadd_pd(ymm4, ymm0, ymm12); \ + ymm13 = _mm256_fmadd_pd(ymm5, ymm0, ymm13); \ + ymm14 = _mm256_fmadd_pd(ymm6, ymm0, ymm14); \ + ymm15 = _mm256_fmadd_pd(ymm7, ymm0, ymm15); \ + ymm16 = _mm256_fmadd_pd(ymm8, ymm0, ymm16); \ + \ + a01 += 1; \ + b10 += cs_b; \ + } \ + for (; itr2 > 0; itr2--) \ + { \ + ymm23 = _mm256_broadcast_sd(b10); \ + \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 0))); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 1))); \ + ymm19 = _mm256_broadcast_sd((a01_2 + (p_lda * 2))); \ + ymm20 = _mm256_broadcast_sd((a01_2 + (p_lda * 3))); \ + ymm21 = _mm256_broadcast_sd((a01_2 + (p_lda * 4))); \ + ymm22 = _mm256_broadcast_sd((a01_2 + (p_lda * 5))); \ + \ + _mm_prefetch((b10_2 + 4 * cs_b), _MM_HINT_T0); \ + ymm24 = _mm256_fmadd_pd(ymm17, ymm23, ymm24); \ + ymm17 = _mm256_broadcast_sd((a01_2 + (p_lda * 6))); \ + ymm25 = _mm256_fmadd_pd(ymm18, ymm23, ymm25); \ + ymm18 = _mm256_broadcast_sd((a01_2 + (p_lda * 7))); \ + ymm26 = _mm256_fmadd_pd(ymm19, ymm23, ymm26); \ + ymm27 = _mm256_fmadd_pd(ymm20, ymm23, ymm27); \ + ymm28 = _mm256_fmadd_pd(ymm21, ymm23, ymm28); \ + ymm29 = _mm256_fmadd_pd(ymm22, ymm23, ymm29); \ + ymm30 = _mm256_fmadd_pd(ymm17, ymm23, ymm30); \ + ymm31 = _mm256_fmadd_pd(ymm18, ymm23, ymm31); \ + \ + a01_2 += 1; \ + b10_2 += cs_b; \ + } \ + /*combine the results of both loops*/ \ + _mm_prefetch((b11 + (0) * cs_b), _MM_HINT_T0); \ + ymm9 = _mm256_add_pd(ymm9, ymm24); \ + _mm_prefetch((b11 + (1) * cs_b), _MM_HINT_T0); \ + ymm10 = _mm256_add_pd(ymm10, ymm25); \ + _mm_prefetch((b11 + (2) * cs_b), _MM_HINT_T0); \ + ymm11 = _mm256_add_pd(ymm11, ymm26); \ + _mm_prefetch((b11 + (3) * cs_b), _MM_HINT_T0); \ + ymm12 = _mm256_add_pd(ymm12, ymm27); \ + _mm_prefetch((b11 + (4) * cs_b), _MM_HINT_T0); \ + ymm13 = _mm256_add_pd(ymm13, ymm28); \ + _mm_prefetch((b11 + (5) * cs_b), _MM_HINT_T0); \ + ymm14 = _mm256_add_pd(ymm14, ymm29); \ + _mm_prefetch((b11 + (6) * cs_b), _MM_HINT_T0); \ + ymm15 = _mm256_add_pd(ymm15, ymm30); \ + _mm_prefetch((b11 + (7) * cs_b), _MM_HINT_T0); \ + ymm16 = _mm256_add_pd(ymm16, ymm31); + + + +#define BLIS_DTRSM_SMALL_GEMM_4nx8m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 8x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/ \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + ymm10 = _mm256_fmadd_pd(ymm2, ymm1, ymm10); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx4m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 4x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/ \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/ \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx3m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 3x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/ \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx2m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 2x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/ \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_4nx1m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 1x1 block of B10*/ \ + ymm0 = _mm256_broadcast_sd((double const *)b10); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 3)); /*A01[0][3]*/ \ + ymm9 = _mm256_fmadd_pd(ymm2, ymm0, ymm9); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx8m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 8x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + ymm8 = _mm256_fmadd_pd(ymm2, ymm1, ymm8); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx4m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 4x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/ \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx3m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 3x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx2m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 2x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_3nx1m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 1x1 block of B10*/ \ + ymm0 = _mm256_broadcast_sd((double const *)b10); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 2)); /*A01[0][2]*/ \ + ymm7 = _mm256_fmadd_pd(ymm2, ymm0, ymm7); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx8m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 8x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/ \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); /*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/ \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + ymm6 = _mm256_fmadd_pd(ymm2, ymm1, ymm6); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx4m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 4x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/ \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx3m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 3x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx2m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 2x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_2nx1m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 1x1 block of B10*/ \ + ymm0 = _mm256_broadcast_sd((double const *)b10); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 1)); /*A01[0][1]*/ \ + ymm5 = _mm256_fmadd_pd(ymm2, ymm0, ymm5); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx8m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 8x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/ \ + ymm1 = _mm256_loadu_pd((double const *)(b10 + 4)); /*B10[4][0] B10[5][0] B10[6][0] B10[7][0]*/ \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + ymm4 = _mm256_fmadd_pd(ymm2, ymm1, ymm4); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx4m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 4x1 block of B10*/ \ + ymm0 = _mm256_loadu_pd((double const *)b10); /*B10[0][0] B10[1][0] B10[2][0] B10[3][0]*/ \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx3m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 3x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b10 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx2m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 2x1 block of B10*/ \ + xmm5 = _mm_loadu_pd((double const *)(b10)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +#define BLIS_DTRSM_SMALL_GEMM_1nx1m(a01, b10, cs_b, p_lda, k_iter) \ + for (k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/ \ + { \ + /*load 1x1 block of B10*/ \ + ymm0 = _mm256_broadcast_sd((double const *)b10); \ + \ + /*broadcast 1st row of A01*/ \ + ymm2 = _mm256_broadcast_sd((double const *)(a01 + p_lda * 0)); /*A01[0][0]*/ \ + ymm3 = _mm256_fmadd_pd(ymm2, ymm0, ymm3); \ + \ + a01 += 1; /*move to next row*/ \ + b10 += cs_b; \ + } + +// endregion - GEMM DTRSM for right variants + +// region - pre/post DTRSM macros for right variants + +#define BLIS_PRE_DTRSM_SMALL_8x8(AlphaVal, b11, cs_b) \ + /*gemm_output = (B11 * alpha) - gemm_output*/ \ + zmm31 = _mm512_set1_pd(AlphaVal); \ + \ + zmm0 = _mm512_loadu_pd((double const *)(b11 + (0*cs_b))); \ + zmm9 = _mm512_fmsub_pd(zmm0, zmm31, zmm9 ); /*zmm9 = (zmm0 * zmm31) - zmm9*/\ + \ + zmm1 = _mm512_loadu_pd((double const *)(b11 + (1*cs_b))); \ + zmm10 = _mm512_fmsub_pd(zmm1, zmm31, zmm10); /*zmm10 = (zmm1 * zmm31) - zmm10*/\ + \ + zmm2 = _mm512_loadu_pd((double const *)(b11 + (2*cs_b))); \ + zmm11 = _mm512_fmsub_pd(zmm2, zmm31, zmm11); /*zmm11 = (zmm2 * zmm31) - zmm11*/\ + \ + zmm3 = _mm512_loadu_pd((double const *)(b11 + (3*cs_b))); \ + zmm12 = _mm512_fmsub_pd(zmm3, zmm31, zmm12); /*zmm12 = (zmm3 * zmm31) - zmm12*/\ + \ + zmm4 = _mm512_loadu_pd((double const *)(b11 + (4*cs_b))); \ + zmm13 = _mm512_fmsub_pd(zmm4, zmm31, zmm13); /*zmm13 = (zmm4 * zmm31) - zmm13*/\ + \ + zmm5 = _mm512_loadu_pd((double const *)(b11 + (5*cs_b))); \ + zmm14 = _mm512_fmsub_pd(zmm5, zmm31, zmm14); /*zmm14 = (zmm5 * zmm31) - zmm14*/\ + \ + zmm6 = _mm512_loadu_pd((double const *)(b11 + (6*cs_b))); \ + zmm15 = _mm512_fmsub_pd(zmm6, zmm31, zmm15); /*zmm15 = (zmm6 * zmm31) - zmm15*/\ + \ + zmm7 = _mm512_loadu_pd((double const *)(b11 + (7*cs_b))); \ + zmm16 = _mm512_fmsub_pd(zmm7, zmm31, zmm16); + +#define BLIS_PRE_DTRSM_SMALL_8x4(AlphaVal, b11, cs_b) \ + /*gemm_output = (B11 * alpha) - gemm_output*/ \ + ymm31 = _mm256_broadcast_sd(&AlphaVal); \ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + 0 * cs_b)); \ + ymm9 = _mm256_fmsub_pd(ymm0, ymm31, ymm9); \ + \ + ymm1 = _mm256_loadu_pd((double const *)(b11 + 1 * cs_b)); \ + ymm10 = _mm256_fmsub_pd(ymm1, ymm31, ymm10); \ + \ + ymm2 = _mm256_loadu_pd((double const *)(b11 + 2 * cs_b)); \ + ymm11 = _mm256_fmsub_pd(ymm2, ymm31, ymm11); \ + \ + ymm3 = _mm256_loadu_pd((double const *)(b11 + 3 * cs_b)); \ + ymm12 = _mm256_fmsub_pd(ymm3, ymm31, ymm12); \ + \ + ymm4 = _mm256_loadu_pd((double const *)(b11 + 4 * cs_b)); \ + ymm13 = _mm256_fmsub_pd(ymm4, ymm31, ymm13); \ + \ + ymm5 = _mm256_loadu_pd((double const *)(b11 + 5 * cs_b)); \ + ymm14 = _mm256_fmsub_pd(ymm5, ymm31, ymm14); \ + \ + ymm6 = _mm256_loadu_pd((double const *)(b11 + 6 * cs_b)); \ + ymm15 = _mm256_fmsub_pd(ymm6, ymm31, ymm15); \ + \ + ymm7 = _mm256_loadu_pd((double const *)(b11 + 7 * cs_b)); \ + ymm16 = _mm256_fmsub_pd(ymm7, ymm31, ymm16); + +#define BLIS_PRE_DTRSM_SMALL_8x3(AlphaVal, b11, cs_b) \ + /*gemm_output = (B11 * alpha) - gemm_output*/ \ + ymm31 = _mm256_broadcast_sd(&AlphaVal); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 0 * cs_b)); /*xmm5 = {b11[0], b11[1]}*/\ + ymm0 = _mm256_broadcast_sd((b11 + 2 + 0 * cs_b)); /*ymm0 = {b11[2], b11[2], b11[2], b11[2]}*/\ + ymm0 = _mm256_insertf64x2(ymm0, xmm5, 0); /*ymm0 = {b11[0], b11[1], b11[2], b11[2]}*/\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm31, ymm9); /*ymm9 = (ymm0 * ymm31) - ymm9*/\ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 1 * cs_b)); \ + ymm1 = _mm256_broadcast_sd((b11 + 2 + 1 * cs_b)); \ + ymm1 = _mm256_insertf64x2(ymm1, xmm5, 0); \ + ymm10 = _mm256_fmsub_pd(ymm1, ymm31, ymm10); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 2 * cs_b)); \ + ymm2 = _mm256_broadcast_sd((b11 + 2 + 2 * cs_b)); \ + ymm2 = _mm256_insertf64x2(ymm2, xmm5, 0); \ + ymm11 = _mm256_fmsub_pd(ymm2, ymm31, ymm11); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 3 * cs_b)); \ + ymm3 = _mm256_broadcast_sd((b11 + 2 + 3 * cs_b)); \ + ymm3 = _mm256_insertf64x2(ymm3, xmm5, 0); \ + ymm12 = _mm256_fmsub_pd(ymm3, ymm31, ymm12); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 4 * cs_b)); \ + ymm4 = _mm256_broadcast_sd((b11 + 2 + 4 * cs_b)); \ + ymm4 = _mm256_insertf64x2(ymm4, xmm5, 0); \ + ymm13 = _mm256_fmsub_pd(ymm4, ymm31, ymm13); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 5 * cs_b)); \ + ymm5 = _mm256_broadcast_sd((b11 + 2 + 5 * cs_b)); \ + ymm5 = _mm256_insertf64x2(ymm5, xmm5, 0); \ + ymm14 = _mm256_fmsub_pd(ymm5, ymm31, ymm14); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 6 * cs_b)); \ + ymm6 = _mm256_broadcast_sd((b11 + 2 + 6 * cs_b)); \ + ymm6 = _mm256_insertf64x2(ymm6, xmm5, 0); \ + ymm15 = _mm256_fmsub_pd(ymm6, ymm31, ymm15); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 7 * cs_b)); \ + ymm7 = _mm256_broadcast_sd((b11 + 2 + 7 * cs_b)); \ + ymm7 = _mm256_insertf64x2(ymm7, xmm5, 0); \ + ymm16 = _mm256_fmsub_pd(ymm7, ymm31, ymm16); + +#define BLIS_PRE_DTRSM_SMALL_8x2(AlphaVal, b11, cs_b) \ + /*gemm_output = (B11 * alpha) - gemm_output*/ \ + ymm31 = _mm256_broadcast_sd(&AlphaVal); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 0 * cs_b)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm31, ymm9); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 1 * cs_b)); \ + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0);\ + ymm10 = _mm256_fmsub_pd(ymm1, ymm31, ymm10); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 2 * cs_b)); \ + ymm2 = _mm256_insertf128_pd(ymm2, xmm5, 0);\ + ymm11 = _mm256_fmsub_pd(ymm2, ymm31, ymm11); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 3 * cs_b)); \ + ymm3 = _mm256_insertf128_pd(ymm3, xmm5, 0);\ + ymm12 = _mm256_fmsub_pd(ymm3, ymm31, ymm12); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 4 * cs_b)); \ + ymm4 = _mm256_insertf128_pd(ymm4, xmm5, 0);\ + ymm13 = _mm256_fmsub_pd(ymm4, ymm31, ymm13); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 5 * cs_b)); \ + ymm5 = _mm256_insertf128_pd(ymm5, xmm5, 0);\ + ymm14 = _mm256_fmsub_pd(ymm5, ymm31, ymm14); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 6 * cs_b)); \ + ymm6 = _mm256_insertf128_pd(ymm6, xmm5, 0);\ + ymm15 = _mm256_fmsub_pd(ymm6, ymm31, ymm15); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + 7 * cs_b)); \ + ymm7 = _mm256_insertf128_pd(ymm7, xmm5, 0);\ + ymm16 = _mm256_fmsub_pd(ymm7, ymm31, ymm16); + +#define BLIS_PRE_DTRSM_SMALL_8x1(AlphaVal, b11, cs_b) \ + /*gemm_output = (B11 * alpha) - gemm_output*/ \ + ymm31 = _mm256_broadcast_sd(&AlphaVal); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 0 * cs_b)); \ + ymm0 = _mm256_insertf64x2(ymm0, xmm5, 0); \ + ymm9 = _mm256_fmsub_pd(ymm0, ymm31, ymm9); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 1 * cs_b)); \ + ymm1 = _mm256_insertf64x2(ymm1, xmm5, 0); \ + ymm10 = _mm256_fmsub_pd(ymm1, ymm31, ymm10); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 2 * cs_b)); \ + ymm2 = _mm256_insertf64x2(ymm2, xmm5, 0); \ + ymm11 = _mm256_fmsub_pd(ymm2, ymm31, ymm11); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 3 * cs_b)); \ + ymm3 = _mm256_insertf64x2(ymm3, xmm5, 0); \ + ymm12 = _mm256_fmsub_pd(ymm3, ymm31, ymm12); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 4 * cs_b)); \ + ymm4 = _mm256_insertf64x2(ymm4, xmm5, 0); \ + ymm13 = _mm256_fmsub_pd(ymm4, ymm31, ymm13); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 5 * cs_b)); \ + ymm5 = _mm256_insertf64x2(ymm5, xmm5, 0); \ + ymm14 = _mm256_fmsub_pd(ymm5, ymm31, ymm14); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 6 * cs_b)); \ + ymm6 = _mm256_insertf64x2(ymm6, xmm5, 0); \ + ymm15 = _mm256_fmsub_pd(ymm6, ymm31, ymm15); \ + \ + xmm5 = _mm_loadl_pd(xmm5, (double const *)(b11 + 7 * cs_b)); \ + ymm7 = _mm256_insertf64x2(ymm7, xmm5, 0); \ + ymm16 = _mm256_fmsub_pd(ymm7, ymm31, ymm16); + +#define BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); \ + \ + ymm0 = _mm256_loadu_pd((double const *)b11); \ + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); \ + \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); \ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); \ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); \ + \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); \ + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); \ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2)); \ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 2 + 4)); \ + \ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); \ + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); \ + \ + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3)); \ + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b * 3 + 4)); \ + \ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); \ + ymm10 = _mm256_fmsub_pd(ymm1, ymm15, ymm10); + +#define BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_3M(b11, cs_b) \ + \ + xmm5 = _mm256_castpd256_pd128(ymm3); \ + _mm_storeu_pd((double *)(b11), xmm5); \ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); \ + xmm5 = _mm256_castpd256_pd128(ymm5); \ + _mm_storeu_pd((double *)(b11 + cs_b), xmm5); \ + _mm_storel_pd((b11 + cs_b + 2), _mm256_extractf128_pd(ymm5, 1)); \ + xmm5 = _mm256_castpd256_pd128(ymm7); \ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); \ + _mm_storel_pd((b11 + cs_b * 2 + 2), _mm256_extractf128_pd(ymm7, 1)); + +#define BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_2M(b11, cs_b) \ + \ + xmm5 = _mm256_castpd256_pd128(ymm3); \ + _mm_storeu_pd((double *)(b11), xmm5); \ + xmm5 = _mm256_castpd256_pd128(ymm5); \ + _mm_storeu_pd((double *)(b11 + cs_b), xmm5); \ + xmm5 = _mm256_castpd256_pd128(ymm7); \ + _mm_storeu_pd((double *)(b11 + cs_b * 2), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + ymm0 = _mm256_broadcast_sd((double const *)b11); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); \ + \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 2)); \ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + +#define BLIS_POST_DTRSM_SMALL_3N_1M(b11, cs_b) \ + \ + _mm_storel_pd((b11 + cs_b * 0), _mm256_castpd256_pd128(ymm3)); \ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); \ + _mm_storel_pd((b11 + cs_b * 2), _mm256_castpd256_pd128(ymm7)); + +#define BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b * 1 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_3M(b11, cs_b) \ + \ + xmm5 = _mm256_castpd256_pd128(ymm3); \ + _mm_storeu_pd((double *)(b11), xmm5); \ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); \ + xmm5 = _mm256_castpd256_pd128(ymm5); \ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); \ + _mm_storel_pd((b11 + cs_b * 1 + 2), _mm256_extractf128_pd(ymm5, 1)); + +#define BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b * 1)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_2M(b11, cs_b) \ + \ + xmm5 = _mm256_castpd256_pd128(ymm3); \ + _mm_storeu_pd((double *)(b11), xmm5); \ + xmm5 = _mm256_castpd256_pd128(ymm5); \ + _mm_storeu_pd((double *)(b11 + cs_b * 1), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + ymm0 = _mm256_broadcast_sd((double const *)b11); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); \ + \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); \ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + +#define BLIS_POST_DTRSM_SMALL_2N_1M(b11, cs_b) \ + \ + _mm_storel_pd(b11, _mm256_castpd256_pd128(ymm3)); \ + _mm_storel_pd((b11 + cs_b * 1), _mm256_castpd256_pd128(ymm5)); + +#define BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11)); \ + ymm0 = _mm256_broadcast_sd((double const *)(b11 + 2)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_3M(b11, cs_b) \ + xmm5 = _mm256_castpd256_pd128(ymm3); \ + _mm_storeu_pd((double *)(b11), xmm5); \ + _mm_storel_pd((b11 + 2), _mm256_extractf128_pd(ymm3, 1)); + +#define BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + xmm5 = _mm_loadu_pd((double const *)(b11)); \ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_2M(b11, cs_b) \ + \ + xmm5 = _mm256_castpd256_pd128(ymm3); \ + _mm_storeu_pd((double *)(b11), xmm5); + +#define BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal, b11, cs_b) \ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/ \ + \ + ymm0 = _mm256_broadcast_sd((double const *)b11); \ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + +#define BLIS_POST_DTRSM_SMALL_1N_1M(b11, cs_b) \ + \ + _mm_storel_pd(b11, _mm256_castpd256_pd128(ymm3)); + +// endregion - pre/post DTRSM macros for right variants + +// RUNN - RLTN +BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + dim_t m = bli_obj_length(b); //number of rows + dim_t n = bli_obj_width(b); // number of columns + dim_t d_mr = 8, d_nr = 8; + + bool transa = bli_obj_has_trans(a); + dim_t cs_a, rs_a; + double ones = 1.0; + + // Swap rs_a & cs_a in case of non-transpose. + if (transa) + { + cs_a = bli_obj_col_stride(a); // column stride of A + rs_a = bli_obj_row_stride(a); // row stride of A + } + else + { + cs_a = bli_obj_row_stride(a); // row stride of A + rs_a = bli_obj_col_stride(a); // column stride of A + } + + dim_t cs_b = bli_obj_col_stride(b); // column stride of B + + dim_t i, j, k; + dim_t k_iter; + + bool is_unitdiag = bli_obj_has_unit_diag(a); + + double AlphaVal = *(double *)AlphaObj->buffer; + double *restrict L = a->buffer; // pointer to matrix A + double *B = bli_obj_buffer_at_off(b); // pointer to matrix B + + double *a01, *a11, *b10, *b11; // pointers for GEMM and TRSM blocks + + gint_t required_packing_A = 1; + mem_t local_mem_buf_A_s = {0}; + double *D_A_pack = NULL; // pointer to A01 pack buffer + double d11_pack[d_mr] __attribute__((aligned(64))); // buffer for diagonal A pack + rntm_t rntm; + + bli_rntm_init_from_global(&rntm); + bli_rntm_set_num_threads_only(1, &rntm); + bli_membrk_rntm_set_membrk(&rntm); + + siz_t buffer_size = bli_pool_block_size( + bli_membrk_pool( + bli_packbuf_index(BLIS_BITVAL_BUFFER_FOR_A_BLOCK), + bli_rntm_membrk(&rntm))); + + if ((d_nr * n * sizeof(double)) > buffer_size) + return BLIS_NOT_YET_IMPLEMENTED; + + if (required_packing_A == 1) + { + // Get the buffer from the pool. + bli_membrk_acquire_m(&rntm, + buffer_size, + BLIS_BITVAL_BUFFER_FOR_A_BLOCK, + &local_mem_buf_A_s); // acquire memory for A01 pack + if (FALSE == bli_mem_is_alloc(&local_mem_buf_A_s)) + return BLIS_NULL_POINTER; + D_A_pack = bli_mem_buffer(&local_mem_buf_A_s); + if (NULL == D_A_pack) + return BLIS_NULL_POINTER; + } + + __m512d zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7, zmm8, zmm9, zmm10, zmm11; + __m512d zmm12, zmm13, zmm14, zmm15, zmm16, zmm17, zmm18, zmm19, zmm20, zmm21; + __m512d zmm22, zmm23, zmm24, zmm25, zmm26, zmm27, zmm28, zmm29, zmm30, zmm31; + __m256d ymm0, ymm1, ymm2, ymm3, ymm4, ymm5, ymm6, ymm7, ymm8, ymm9, ymm10, ymm11; + __m256d ymm12, ymm13, ymm14, ymm15, ymm16, ymm17, ymm18, ymm19, ymm20, ymm21; + __m256d ymm22, ymm23, ymm24, ymm25, ymm26, ymm27, ymm28, ymm29, ymm30, ymm31; + __m128d xmm5, xmm0; + + /* + Performs solving TRSM for 8 rows at a time from 0 to n/8 in steps of d_nr + a. Load and pack A (a01 block), the size of packing 8x8 to 8x(n-8) + First there will be no GEMM and no packing of a01 because it is only TRSM + b. Using packed a01 block and b10 block perform GEMM operation + c. Use GEMM outputs, perform TRSM operation using a11, b11 and update B + d. Repeat b for m cols of B in steps of d_mr + */ + for (j = 0; (j + d_nr - 1) < n; j += d_nr) //loop along 'N' direction + { + a01 = L + j * rs_a; //pointer to block of A to be used in GEMM + a11 = L + j * cs_a + j * rs_a; //pointer to block of A to be used for TRSM + + dim_t p_lda = j; //packed leading dimension + + // perform copy of A to packed buffer D_A_pack + if (transa) + { + /* + Pack current A block (a01) into packed buffer memory D_A_pack + a. This a10 block is used in GEMM portion only and this + a01 block size will be increasing by d_nr for every next iteration + until it reaches 8x(n-8) which is the maximum GEMM alone block size in A + b. This packed buffer is reused to calculate all m cols of B matrix + */ + bli_dtrsm_small_pack_avx512('R', j, 1, a01, cs_a, D_A_pack, p_lda, d_nr); + /* + Pack 8 diagonal elements of A block into an array + a. This helps to utilize cache line efficiently in TRSM operation + b. store ones when input is unit diagonal + */ + dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, cs_a, d11_pack, d_nr); + } + else + { + bli_dtrsm_small_pack_avx512('R', j, 0, a01, rs_a, D_A_pack, p_lda, d_nr); + dtrsm_small_pack_diag_element_avx512(is_unitdiag, a11, rs_a, d11_pack, d_nr); + } + + /* + a. Perform GEMM using a01, b10. + b. Perform TRSM on a11, b11 + c. This loop GEMM+TRSM loops operates with 8x6 block size + along m dimension for every d_mr columns of B10 where + packed A buffer is reused in computing all m cols of B. + d. Same approach is used in remaining fringe cases. + */ + for (i = 0; (i + d_mr - 1) < m; i += d_mr) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; + BLIS_SET_ZMM_REG_ZEROS + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx8m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11); + /* + Load b11 of size 8x8 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x8(AlphaVal, b11, cs_b) + + + + /* + Compute 8x8 TRSM block by using GEMM block output in register + a. The 8x8 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : zmm9 zmm10 zmm11 zmm12 zmm13 zmm14 zmm15 zmm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to n-1 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[i][N] } /A11[N][N] + * from i=0 + * + * For example 5th column (B11[5]) -= ((B11[0] * A11[0][5]) + (B11[1] * A11[2][5]) + + * (B11[2] * A11[2][5]) + (B11[3] * A11[3][5]) + + * (B11[4] * A11[4][5])) / A11[5][5] + * zmm14 -= ((zmm9 * A11[0][5]) + (zmm10 * A11[2][5]) + + * (zmm11 * A11[2][5]) + (zmm12 * A11[3][5]) + + * (zmm13 * A11[4][5])) / A11[5][5] + */ + + + // extract a00 + zmm0 = _mm512_set1_pd(*(d11_pack + 0)); + zmm9 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm9, zmm0); // zmm9 /= zmm0 + _mm512_storeu_pd((double *)(b11 + (0 * cs_b)), zmm9); + + // extract a11 + zmm1 = _mm512_set1_pd(*(d11_pack + 1)); + zmm2 = _mm512_set1_pd(*(a11 + (1 * rs_a))); + zmm10 = _mm512_fnmadd_pd(zmm2, zmm9, zmm10); + zmm3 = _mm512_set1_pd(*(a11 + (2 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm3, zmm9, zmm11); + zmm4 = _mm512_set1_pd(*(a11 + (3 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm4, zmm9, zmm12); + zmm5 = _mm512_set1_pd(*(a11 + (4 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm5, zmm9, zmm13); + zmm6 = _mm512_set1_pd(*(a11 + (5 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm6, zmm9, zmm14); // zmm14 -= A11[0][5] * zmm9 + zmm7 = _mm512_set1_pd(*(a11 + (6 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm7, zmm9, zmm15); + zmm8 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm8, zmm9, zmm16); + zmm10 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm10, zmm1); + a11 += cs_a; + _mm512_storeu_pd((double *)(b11 + (1 * cs_b)), zmm10); + + // extract a22 + zmm0 = _mm512_set1_pd(*(d11_pack + 2)); + zmm2 = _mm512_set1_pd(*(a11 + (2 * rs_a))); + zmm11 = _mm512_fnmadd_pd(zmm2, zmm10, zmm11); + zmm3 = _mm512_set1_pd(*(a11 + (3 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm3, zmm10, zmm12); + zmm4 = _mm512_set1_pd(*(a11 + (4 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm4, zmm10, zmm13); + zmm5 = _mm512_set1_pd(*(a11 + (5 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm5, zmm10, zmm14); // zmm14 -= A11[1][5] * zmm10 + zmm6 = _mm512_set1_pd(*(a11 + (6 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm6, zmm10, zmm15); + zmm7 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm7, zmm10, zmm16); + zmm11 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm11, zmm0); + a11 += cs_a; + _mm512_storeu_pd((double *)(b11 + (2 * cs_b)), zmm11); + + // extract a33 + zmm1 = _mm512_set1_pd(*(d11_pack + 3)); + zmm2 = _mm512_set1_pd(*(a11 + (3 * rs_a))); + zmm12 = _mm512_fnmadd_pd(zmm2, zmm11, zmm12); + zmm3 = _mm512_set1_pd(*(a11 + (4 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm3, zmm11, zmm13); + zmm4 = _mm512_set1_pd(*(a11 + (5 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm4, zmm11, zmm14); // zmm14 -= A11[2][5] * zmm11 + zmm5 = _mm512_set1_pd(*(a11 + (6 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm5, zmm11, zmm15); + zmm6 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm6, zmm11, zmm16); + zmm12 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm12, zmm1); + a11 += cs_a; + _mm512_storeu_pd((double *)(b11 + (3 * cs_b)), zmm12); + + // extract a44 + zmm0 = _mm512_set1_pd(*(d11_pack + 4)); + zmm2 = _mm512_set1_pd(*(a11 + (4 * rs_a))); + zmm13 = _mm512_fnmadd_pd(zmm2, zmm12, zmm13); + zmm3 = _mm512_set1_pd(*(a11 + (5 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm3, zmm12, zmm14); // zmm14 -= A11[3][5] * zmm12 + zmm4 = _mm512_set1_pd(*(a11 + (6 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm4, zmm12, zmm15); + zmm5 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm5, zmm12, zmm16); + zmm13 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm13, zmm0); + a11 += cs_a; + _mm512_storeu_pd((double *)(b11 + (4 * cs_b)), zmm13); + + // extract a55 + zmm1 = _mm512_set1_pd(*(d11_pack + 5)); + zmm2 = _mm512_set1_pd(*(a11 + (5 * rs_a))); + zmm14 = _mm512_fnmadd_pd(zmm2, zmm13, zmm14); // zmm14 -= A11[4][5] * zmm13 + zmm3 = _mm512_set1_pd(*(a11 + (6 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm3, zmm13, zmm15); + zmm4 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm4, zmm13, zmm16); + zmm14 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm14, zmm1); // zmm14 /= A11[5][5] + a11 += cs_a; + _mm512_storeu_pd((double *)(b11 + (5 * cs_b)), zmm14); + + // extract a66 + zmm0 = _mm512_set1_pd(*(d11_pack + 6)); + zmm2 = _mm512_set1_pd(*(a11 + (6 * rs_a))); + zmm15 = _mm512_fnmadd_pd(zmm2, zmm14, zmm15); + zmm3 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm3, zmm14, zmm16); + zmm15 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm15, zmm0); + a11 += cs_a; + _mm512_storeu_pd((double *)(b11 + (6 * cs_b)), zmm15); + + // extract a77 + zmm1 = _mm512_set1_pd(*(d11_pack + 7)); + zmm2 = _mm512_set1_pd(*(a11 + (7 * rs_a))); + zmm16 = _mm512_fnmadd_pd(zmm2, zmm15, zmm16); + zmm16 = DTRSM_SMALL_DIV_OR_SCALE_AVX512(zmm16, zmm1); + _mm512_storeu_pd((double *)(b11 + (7 * cs_b)), zmm16); + } + dim_t m_remainder = m - i; + if(m_remainder) + { + if (m_remainder >= 4) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; + BLIS_SET_YMM_REG_ZEROS_AVX512 + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx4m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x4 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x4(AlphaVal, b11, cs_b) + + /* + Compute 8x4 TRSM block by using GEMM block output in register + a. The 8x4 input (gemm outputs) are stored in combinations of ymm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to n-1 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[i][N] } /A11[N][N] + * from i=0 + * + * For example 5th column (B11[5]) -= ((B11[0] * A11[0][5]) + (B11[1] * A11[2][5]) + + * (B11[2] * A11[2][5]) + (B11[3] * A11[3][5]) + + * (B11[4] * A11[4][5])) / A11[5][5] + */ + + // extract a00 + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm256_storeu_pd((double *)(b11 + (0 * cs_b)), ymm9); + + // extract a11 + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm2 = _mm256_broadcast_sd((a11 + (1 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm3 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm3, ymm9, ymm11); + ymm4 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm9, ymm12); + ymm5 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm5, ymm9, ymm13); + ymm6 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm6, ymm9, ymm14); + ymm7 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm7, ymm9, ymm15); + ymm8 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm8, ymm9, ymm16); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + a11 += cs_a; + _mm256_storeu_pd((double *)(b11 + (1 * cs_b)), ymm10); + + // extract a22 + ymm0 = _mm256_broadcast_sd((d11_pack + 2)); + ymm2 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm3 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm3, ymm10, ymm12); + ymm4 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm4, ymm10, ymm13); + ymm5 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm5, ymm10, ymm14); + ymm6 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm6, ymm10, ymm15); + ymm7 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm7, ymm10, ymm16); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + a11 += cs_a; + _mm256_storeu_pd((double *)(b11 + (2 * cs_b)), ymm11); + + // extract a33 + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm2 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm3 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm11, ymm13); + ymm4 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm4, ymm11, ymm14); + ymm5 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm5, ymm11, ymm15); + ymm6 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm6, ymm11, ymm16); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + a11 += cs_a; + _mm256_storeu_pd((double *)(b11 + (3 * cs_b)), ymm12); + + // extract a44 + ymm0 = _mm256_broadcast_sd((d11_pack + 4)); + ymm2 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm3 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); + ymm4 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + ymm5 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm5, ymm12, ymm16); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + a11 += cs_a; + _mm256_storeu_pd((double *)(b11 + (4 * cs_b)), ymm13); + + // extract a55 + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm2 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm3 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm3, ymm13, ymm15); + ymm4 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm4, ymm13, ymm16); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + a11 += cs_a; + _mm256_storeu_pd((double *)(b11 + (5 * cs_b)), ymm14); + + // extract a66 + ymm0 = _mm256_broadcast_sd((d11_pack + 6)); + ymm2 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm3 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm3, ymm14, ymm16); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm0); + a11 += cs_a; + _mm256_storeu_pd((double *)(b11 + (6 * cs_b)), ymm15); + + // extract a77 + ymm1 = _mm256_broadcast_sd((d11_pack + 7)); + ymm2 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm2, ymm15, ymm16); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + _mm256_storeu_pd((double *)(b11 + (7 * cs_b)), ymm16); + m_remainder -= 4; + i += 4; + } + if (m_remainder == 3) //loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; //pointer to block of A to be used for TRSM + b10 = B + i; //pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; //pointer to block of B to be used for TRSM + + k_iter = j; + BLIS_SET_YMM_REG_ZEROS_AVX512 + /* + Perform GEMM between a01 and b10 blocks + For first iteration there will be no GEMM operation + where k_iter are zero + */ + BLIS_DTRSM_SMALL_GEMM_8nx3m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + /* + Load b11 of size 8x3 and multiply with alpha + Add the GEMM output to b11 + and perform TRSM operation. + */ + BLIS_PRE_DTRSM_SMALL_8x3(AlphaVal, b11, cs_b) + /* + Compute 8x3 TRSM block by using GEMM block output in register + a. The 8x3 input (gemm outputs) are stored in combinations of ymm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to n-1 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[i][N] } /A11[N][N] + * from i=0 + * + * For example 5th column (B11[5]) -= ((B11[0] * A11[0][5]) + (B11[1] * A11[2][5]) + + * (B11[2] * A11[2][5]) + (B11[3] * A11[3][5]) + + * (B11[4] * A11[4][5])) / A11[5][5] + */ + + + // extract a00 + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storeu_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + _mm_storel_pd((double *)(b11 + (0 * cs_b) + 2), _mm256_extractf64x2_pd(ymm9, 1)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm2 = _mm256_broadcast_sd((a11 + (1 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm3 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm3, ymm9, ymm11); + ymm4 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm9, ymm12); + ymm5 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm5, ymm9, ymm13); + ymm6 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm6, ymm9, ymm14); + ymm7 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm7, ymm9, ymm15); + ymm8 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm8, ymm9, ymm16); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + _mm_storel_pd((double *)(b11 + (1 * cs_b) + 2), _mm256_extractf64x2_pd(ymm10, 1)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((d11_pack + 2)); + ymm2 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm3 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm3, ymm10, ymm12); + ymm4 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm4, ymm10, ymm13); + ymm5 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm5, ymm10, ymm14); + ymm6 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm6, ymm10, ymm15); + ymm7 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm7, ymm10, ymm16); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + _mm_storel_pd((double *)(b11 + (2 * cs_b) + 2), _mm256_extractf64x2_pd(ymm11, 1)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm2 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm3 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm11, ymm13); + ymm4 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm4, ymm11, ymm14); + ymm5 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm5, ymm11, ymm15); + ymm6 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm6, ymm11, ymm16); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + _mm_storel_pd((double *)(b11 + (3 * cs_b) + 2), _mm256_extractf64x2_pd(ymm12, 1)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((d11_pack + 4)); + ymm2 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm3 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); + ymm4 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + ymm5 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm5, ymm12, ymm16); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + _mm_storel_pd((double *)(b11 + (4 * cs_b) + 2), _mm256_extractf64x2_pd(ymm13, 1)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm2 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm3 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm3, ymm13, ymm15); + ymm4 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm4, ymm13, ymm16); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + _mm_storel_pd((double *)(b11 + (5 * cs_b) + 2), _mm256_extractf64x2_pd(ymm14, 1)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((d11_pack + 6)); + ymm2 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm3 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm3, ymm14, ymm16); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm0); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + _mm_storel_pd((double *)(b11 + (6 * cs_b) + 2), _mm256_extractf64x2_pd(ymm15, 1)); + + // extract a77 + ymm1 = _mm256_broadcast_sd((d11_pack + 7)); + ymm2 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm2, ymm15, ymm16); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + _mm_storeu_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + _mm_storel_pd((double *)(b11 + (7 * cs_b) + 2), _mm256_extractf64x2_pd(ymm16, 1)); + m_remainder -= 3; + i += 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; + b10 = B + i; + b11 = B + i + j * cs_b; + + k_iter = j; + BLIS_SET_YMM_REG_ZEROS_AVX512 + BLIS_DTRSM_SMALL_GEMM_8nx2m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + BLIS_PRE_DTRSM_SMALL_8x2(AlphaVal, b11, cs_b) + /* + Compute 8x2 TRSM block by using GEMM block output in register + a. The 8x2 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to n-1 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[i][N] } /A11[N][N] + * from i=0 + * + * For example 5th column (B11[5]) -= ((B11[0] * A11[0][5]) + (B11[1] * A11[2][5]) + + * (B11[2] * A11[2][5]) + (B11[3] * A11[3][5]) + + * (B11[4] * A11[4][5])) / A11[5][5] + */ + // extract a00 + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storeu_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm2 = _mm256_broadcast_sd((a11 + (1 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm3 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm3, ymm9, ymm11); + ymm4 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm9, ymm12); + ymm5 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm5, ymm9, ymm13); + ymm6 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm6, ymm9, ymm14); + ymm7 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm7, ymm9, ymm15); + ymm8 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm8, ymm9, ymm16); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((d11_pack + 2)); + ymm2 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm3 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm3, ymm10, ymm12); + ymm4 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm4, ymm10, ymm13); + ymm5 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm5, ymm10, ymm14); + ymm6 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm6, ymm10, ymm15); + ymm7 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm7, ymm10, ymm16); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm2 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm3 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm11, ymm13); + ymm4 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm4, ymm11, ymm14); + ymm5 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm5, ymm11, ymm15); + ymm6 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm6, ymm11, ymm16); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((d11_pack + 4)); + ymm2 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm3 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); + ymm4 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + ymm5 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm5, ymm12, ymm16); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm2 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm3 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm3, ymm13, ymm15); + ymm4 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm4, ymm13, ymm16); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((d11_pack + 6)); + ymm2 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm3 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm3, ymm14, ymm16); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm0); + a11 += cs_a; + _mm_storeu_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + + // extract a77 + ymm1 = _mm256_broadcast_sd((d11_pack + 7)); + ymm2 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm2, ymm15, ymm16); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + _mm_storeu_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + m_remainder -= 2; + i += 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; + b10 = B + i; + b11 = B + i + j * cs_b; + + k_iter = j; + BLIS_SET_YMM_REG_ZEROS_AVX512 + BLIS_DTRSM_SMALL_GEMM_8nx1m_AVX512(a01, b10, cs_b, p_lda, k_iter, b11) + BLIS_PRE_DTRSM_SMALL_8x1(AlphaVal, b11, cs_b) + /* + Compute 8x1 TRSM block by using GEMM block output in register + a. The 8x1 input (gemm outputs) are stored in combinations of zmm registers + row : 0 1 2 3 4 5 6 7 + register : ymm9 ymm10 ymm11 ymm12 ymm13 ymm14 ymm15 ymm16 + b. Towards the end TRSM output will be stored back into b11 + */ + + /* + * to n-1 + * B11[Nth column] = GEMM(Nth column) - Σ { B11[i] * A11[i][N] } /A11[N][N] + * from i=0 + * + * For example 5th column (B11[5]) -= ((B11[0] * A11[0][5]) + (B11[1] * A11[2][5]) + + * (B11[2] * A11[2][5]) + (B11[3] * A11[3][5]) + + * (B11[4] * A11[4][5])) / A11[5][5] + */ + // extract a00 + ymm0 = _mm256_broadcast_sd((d11_pack + 0)); + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + _mm_storel_pd((double *)(b11 + (0 * cs_b)), _mm256_castpd256_pd128(ymm9)); + + // extract a11 + ymm1 = _mm256_broadcast_sd((d11_pack + 1)); + ymm2 = _mm256_broadcast_sd((a11 + (1 * rs_a))); + ymm10 = _mm256_fnmadd_pd(ymm2, ymm9, ymm10); + ymm3 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm3, ymm9, ymm11); + ymm4 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm4, ymm9, ymm12); + ymm5 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm5, ymm9, ymm13); + ymm6 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm6, ymm9, ymm14); + ymm7 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm7, ymm9, ymm15); + ymm8 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm8, ymm9, ymm16); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm1); + a11 += cs_a; + _mm_storel_pd((double *)(b11 + (1 * cs_b)), _mm256_castpd256_pd128(ymm10)); + + // extract a22 + ymm0 = _mm256_broadcast_sd((d11_pack + 2)); + ymm2 = _mm256_broadcast_sd((a11 + (2 * rs_a))); + ymm11 = _mm256_fnmadd_pd(ymm2, ymm10, ymm11); + ymm3 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm3, ymm10, ymm12); + ymm4 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm4, ymm10, ymm13); + ymm5 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm5, ymm10, ymm14); + ymm6 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm6, ymm10, ymm15); + ymm7 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm7, ymm10, ymm16); + ymm11 = DTRSM_SMALL_DIV_OR_SCALE(ymm11, ymm0); + a11 += cs_a; + _mm_storel_pd((double *)(b11 + (2 * cs_b)), _mm256_castpd256_pd128(ymm11)); + + // extract a33 + ymm1 = _mm256_broadcast_sd((d11_pack + 3)); + ymm2 = _mm256_broadcast_sd((a11 + (3 * rs_a))); + ymm12 = _mm256_fnmadd_pd(ymm2, ymm11, ymm12); + ymm3 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm3, ymm11, ymm13); + ymm4 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm4, ymm11, ymm14); + ymm5 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm5, ymm11, ymm15); + ymm6 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm6, ymm11, ymm16); + ymm12 = DTRSM_SMALL_DIV_OR_SCALE(ymm12, ymm1); + a11 += cs_a; + _mm_storel_pd((double *)(b11 + (3 * cs_b)), _mm256_castpd256_pd128(ymm12)); + + // extract a44 + ymm0 = _mm256_broadcast_sd((d11_pack + 4)); + ymm2 = _mm256_broadcast_sd((a11 + (4 * rs_a))); + ymm13 = _mm256_fnmadd_pd(ymm2, ymm12, ymm13); + ymm3 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm3, ymm12, ymm14); + ymm4 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm4, ymm12, ymm15); + ymm5 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm5, ymm12, ymm16); + ymm13 = DTRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); + a11 += cs_a; + _mm_storel_pd((double *)(b11 + (4 * cs_b)), _mm256_castpd256_pd128(ymm13)); + + // extract a55 + ymm1 = _mm256_broadcast_sd((d11_pack + 5)); + ymm2 = _mm256_broadcast_sd((a11 + (5 * rs_a))); + ymm14 = _mm256_fnmadd_pd(ymm2, ymm13, ymm14); + ymm3 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm3, ymm13, ymm15); + ymm4 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm4, ymm13, ymm16); + ymm14 = DTRSM_SMALL_DIV_OR_SCALE(ymm14, ymm1); + a11 += cs_a; + _mm_storel_pd((double *)(b11 + (5 * cs_b)), _mm256_castpd256_pd128(ymm14)); + + // extract a66 + ymm0 = _mm256_broadcast_sd((d11_pack + 6)); + ymm2 = _mm256_broadcast_sd((a11 + (6 * rs_a))); + ymm15 = _mm256_fnmadd_pd(ymm2, ymm14, ymm15); + ymm3 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm3, ymm14, ymm16); + ymm15 = DTRSM_SMALL_DIV_OR_SCALE(ymm15, ymm0); + a11 += cs_a; + _mm_storel_pd((double *)(b11 + (6 * cs_b)), _mm256_castpd256_pd128(ymm15)); + + // extract a77 + ymm1 = _mm256_broadcast_sd((d11_pack + 7)); + ymm2 = _mm256_broadcast_sd((a11 + (7 * rs_a))); + ymm16 = _mm256_fnmadd_pd(ymm2, ymm15, ymm16); + ymm16 = DTRSM_SMALL_DIV_OR_SCALE(ymm16, ymm1); + _mm_storel_pd((double *)(b11 + (7 * cs_b)), _mm256_castpd256_pd128(ymm16)); + m_remainder -= 1; + i += 1; + } + } + } + + dim_t n_remainder = n - j; + + if (n_remainder >= 4) + { + a01 = L + j * rs_a; // pointer to block of A to be used in GEMM + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + dim_t x = 0; + for (x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); + bli_dcopys(*(a01 + rs_a * 2), *(ptr_a10_dup + (p_lda * 2))); + bli_dcopys(*(a01 + rs_a * 3), *(ptr_a10_dup + (p_lda * 3))); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = p_lda / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 2) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 3) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 3) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 2) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 3) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 3) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2) + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 3) + 3)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 2) + 2)); + ymm3 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 3) + 3)); + } + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = 0; (i + d_mr - 1) < m; i += d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx8m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_4x8(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm4, ymm10); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm6, ymm10); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + a11 += cs_a; + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + ymm10 = _mm256_fnmadd_pd(ymm1, ymm8, ymm10); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + ymm10 = DTRSM_SMALL_DIV_OR_SCALE(ymm10, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b + 4)), ymm6); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2) + 4), ymm8); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3)), ymm9); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3) + 4), ymm10); + } + + dim_t m_remainder = m - i; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 3))); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 3)), ymm9); + + m_remainder -= 4; + i += 4; + } + + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx3m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)(&AlphaVal)); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 3))); + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 3) + 2)); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + _mm_storel_pd((double *)b11 + 2, _mm256_extractf128_pd(ymm3, 1)); + _mm_storel_pd((double *)(b11 + cs_b + 2), _mm256_extractf128_pd(ymm5, 1)); + _mm_storel_pd((double *)(b11 + (cs_b * 2) + 2), _mm256_extractf128_pd(ymm7, 1)); + _mm_storel_pd((double *)(b11 + (cs_b * 3) + 2), _mm256_extractf128_pd(ymm9, 1)); + + m_remainder -= 3; + i += 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx2m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + xmm5 = _mm_loadu_pd((double const *)(b11 + (cs_b * 3))); + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0); + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storeu_pd((double *)b11, _mm256_castpd256_pd128(ymm3)); + _mm_storeu_pd((double *)(b11 + cs_b), _mm256_castpd256_pd128(ymm5)); + _mm_storeu_pd((double *)(b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storeu_pd((double *)(b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + m_remainder -= 2; + i += 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_4nx1m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_broadcast_sd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + ymm0 = _mm256_broadcast_sd((double const *)(b11 + (cs_b * 3))); + // B11[0][3] B11[1][3] B11[2][3] B11[3][3] + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9); + // B11[0-3][3] * alpha -= ymm6 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm3, ymm9); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm5, ymm9); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + a11 += cs_a; + + // extract a33 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 3)); + + //(Row 3): FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (3 * rs_a))); + ymm9 = _mm256_fnmadd_pd(ymm1, ymm7, ymm9); + + ymm9 = DTRSM_SMALL_DIV_OR_SCALE(ymm9, ymm0); + + _mm_storel_pd((b11 + (cs_b * 0)), _mm256_castpd256_pd128(ymm3)); + _mm_storel_pd((b11 + (cs_b * 1)), _mm256_castpd256_pd128(ymm5)); + _mm_storel_pd((b11 + (cs_b * 2)), _mm256_castpd256_pd128(ymm7)); + _mm_storel_pd((b11 + (cs_b * 3)), _mm256_castpd256_pd128(ymm9)); + + m_remainder -= 1; + i += 1; + } + j += 4; + n_remainder -= 4; + } + + if (n_remainder == 3) + { + a01 = L + j * rs_a; // pointer to block of A to be used in GEMM + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + dim_t x = 0; + for (x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + p_lda * 1)); + bli_dcopys(*(a01 + rs_a * 2), *(ptr_a10_dup + p_lda * 2)); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = p_lda / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 2) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 2) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 2) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 2) + 2)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + ymm2 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 2) + 2)); + } + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = 0; (i + d_mr - 1) < m; i += d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx8m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); + // B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); + // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); + // B11[4-7][1] * alpha -= ymm3 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm1 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2) + 4)); + // B11[4][2] B11[5][2] B11[6][2] B11[7][2] + + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + ymm8 = _mm256_fmsub_pd(ymm1, ymm15, ymm8); + // B11[4-7][2] * alpha -= ymm5 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm4, ymm8); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + ymm8 = _mm256_fnmadd_pd(ymm1, ymm6, ymm8); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + ymm8 = DTRSM_SMALL_DIV_OR_SCALE(ymm8, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2) + 4), ymm8); + } + + dim_t m_remainder = m - i; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + (cs_b * 2))); + // B11[0][2] B11[1][2] B11[2][2] B11[3][2] + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7); + // B11[0-3][2] * alpha -= ymm4 + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + (cs_b * 2)), ymm7); + + m_remainder -= 4; + i += 4; + } + + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx3m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_3M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_3M(b11, cs_b) + + m_remainder -= 3; + i += 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx2m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_2M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_2M(b11, cs_b) + + m_remainder -= 2; + i += 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_3nx1m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_3N_1M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm3, ymm7); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + a11 += cs_a; + + // extract a22 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 2)); + + //(row 2):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (2 * rs_a))); + ymm7 = _mm256_fnmadd_pd(ymm1, ymm5, ymm7); + + ymm7 = DTRSM_SMALL_DIV_OR_SCALE(ymm7, ymm0); + + BLIS_POST_DTRSM_SMALL_3N_1M(b11, cs_b) + + m_remainder -= 1; + i += 1; + } + j += 3; + n_remainder -= 3; + } + else if (n_remainder == 2) + { + a01 = L + j * rs_a; // pointer to block of A to be used in GEMM + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + dim_t x = 0; + for (x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + (p_lda * 0))); + bli_dcopys(*(a01 + rs_a * 1), *(ptr_a10_dup + (p_lda * 1))); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = p_lda / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 1) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 1) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 1) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + if (transa) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (cs_a * 1) + 1)); + } + else + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (rs_a * 1) + 1)); + } + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = 0; (i + d_mr - 1) < m; i += d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + /*Fill zeros into ymm registers used in gemm accumulations */ + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + ymm6 = _mm256_setzero_pd(); + ymm7 = _mm256_setzero_pd(); + ymm8 = _mm256_setzero_pd(); + ymm9 = _mm256_setzero_pd(); + ymm10 = _mm256_setzero_pd(); + ymm15 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx8m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); + // B11[4-7][0] * alpha-= ymm1 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b + 4)); + // B11[4][1] B11[5][1] B11[6][1] B11[7][1] + + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + ymm6 = _mm256_fmsub_pd(ymm1, ymm15, ymm6); + // B11[4-7][1] * alpha -= ymm3 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + ymm6 = _mm256_fnmadd_pd(ymm1, ymm4, ymm6); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + ymm6 = DTRSM_SMALL_DIV_OR_SCALE(ymm6, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + _mm256_storeu_pd((double *)(b11 + cs_b + 4), ymm6); + } + + dim_t m_remainder = m - i; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + + ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b)); + // B11[0][1] B11[1][1] B11[2][1] B11[3][1] + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5); + // B11[0-3][1] * alpha-= ymm2 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + cs_b), ymm5); + + m_remainder -= 4; + i += 4; + } + + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx3m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_3M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_3M(b11, cs_b) + + m_remainder -= 3; + i += 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx2m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_2M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_2M(b11, cs_b) + + m_remainder -= 2; + i += 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm5 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_2nx1m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_2N_1M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + // extract a11 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack + 1)); + + //(row 1):FMA operations + ymm1 = _mm256_broadcast_sd((double const *)(a11 + (1 * rs_a))); + ymm5 = _mm256_fnmadd_pd(ymm1, ymm3, ymm5); + + ymm5 = DTRSM_SMALL_DIV_OR_SCALE(ymm5, ymm0); + + BLIS_POST_DTRSM_SMALL_2N_1M(b11, cs_b) + + m_remainder -= 1; + i += 1; + } + j += 2; + n_remainder -= 2; + } + else if (n_remainder == 1) + { + a01 = L + j * rs_a; // pointer to block of A to be used in GEMM + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + + double *ptr_a10_dup = D_A_pack; + + dim_t p_lda = j; // packed leading dimension + // perform copy of A to packed buffer D_A_pack + + if (transa) + { + dim_t x = 0; + for (x = 0; x < p_lda; x += 1) + { + bli_dcopys(*(a01 + rs_a * 0), *(ptr_a10_dup + p_lda * 0)); + ptr_a10_dup += 1; + a01 += cs_a; + } + } + else + { + dim_t loop_count = p_lda / 4; + + for (dim_t x = 0; x < loop_count; x++) + { + ymm15 = _mm256_loadu_pd((double const *)(a01 + (rs_a * 0) + (x * 4))); + _mm256_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (x * 4)), ymm15); + } + + dim_t remainder_loop_count = p_lda - loop_count * 4; + + __m128d xmm0; + if (remainder_loop_count != 0) + { + xmm0 = _mm_loadu_pd((double const *)(a01 + (rs_a * 0) + (loop_count * 4))); + _mm_storeu_pd((double *)(ptr_a10_dup + (p_lda * 0) + (loop_count * 4)), xmm0); + } + } + + ymm4 = _mm256_broadcast_sd((double const *)&ones); + if (!is_unitdiag) + { + // broadcast diagonal elements of A11 + ymm0 = _mm256_broadcast_sd((double const *)(a11)); + ymm1 = _mm256_broadcast_sd((double const *)&ones); + ymm2 = _mm256_broadcast_sd((double const *)&ones); + ymm3 = _mm256_broadcast_sd((double const *)&ones); + + ymm0 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm1 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm1 = _mm256_blend_pd(ymm0, ymm1, 0x0C); +#ifdef BLIS_DISABLE_TRSM_PREINVERSION + ymm4 = ymm1; +#endif +#ifdef BLIS_ENABLE_TRSM_PREINVERSION + ymm4 = _mm256_div_pd(ymm4, ymm1); +#endif + } + _mm256_storeu_pd((double *)(d11_pack), ymm4); + + for (i = 0; (i + d_mr - 1) < m; i += d_mr) // loop along 'M' direction + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + ymm4 = _mm256_setzero_pd(); + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx8m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); + + ymm0 = _mm256_loadu_pd((double const *)b11); + // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm1 = _mm256_loadu_pd((double const *)(b11 + 4)); + // B11[4][0] B11[5][0] B11[6][0] B11[7][0] + + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); + // B11[0-3][0] * alpha -= ymm0 + ymm4 = _mm256_fmsub_pd(ymm1, ymm15, ymm4); + // B11[4-7][0] * alpha-= ymm1 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + ymm4 = DTRSM_SMALL_DIV_OR_SCALE(ymm4, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + _mm256_storeu_pd((double *)(b11 + 4), ymm4); + } + + dim_t m_remainder = m - i; + if (m_remainder >= 4) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx4m(a01, b10, cs_b, p_lda, k_iter) + + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); // register to hold alpha + + ymm0 = _mm256_loadu_pd((double const *)b11); // B11[0][0] B11[1][0] B11[2][0] B11[3][0] + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3); // B11[0-3][0] * alpha -= ymm0 + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + _mm256_storeu_pd((double *)b11, ymm3); + + m_remainder -= 4; + i += 4; + } + + if (m_remainder == 3) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx3m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_3M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_3M(b11, cs_b) + + m_remainder -= 3; + i += 3; + } + else if (m_remainder == 2) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx2m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_2M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_2M(b11, cs_b) + + m_remainder -= 2; + i += 2; + } + else if (m_remainder == 1) + { + a01 = D_A_pack; + a11 = L + j * cs_a + j * rs_a; // pointer to block of A to be used for TRSM + b10 = B + i; // pointer to block of B to be used in GEMM + b11 = B + i + j * cs_b; // pointer to block of B to be used for TRSM + + k_iter = j; // number of GEMM operations to be done(in blocks of 4x4) + + ymm3 = _mm256_setzero_pd(); + + /// GEMM implementation starts/// + BLIS_DTRSM_SMALL_GEMM_1nx1m(a01, b10, cs_b, p_lda, k_iter) + + BLIS_PRE_DTRSM_SMALL_1N_1M(AlphaVal, b11, cs_b) + + /// implement TRSM/// + + // extract a00 + ymm0 = _mm256_broadcast_sd((double const *)(d11_pack)); + ymm3 = DTRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); + + BLIS_POST_DTRSM_SMALL_1N_1M(b11, cs_b) + + m_remainder -= 1; + i += 1; + } + j += 1; + n_remainder -= 1; + } + + if ((required_packing_A == 1) && bli_mem_is_alloc(&local_mem_buf_A_s)) + { + bli_membrk_release(&rntm, + &local_mem_buf_A_s); + } + return BLIS_SUCCESS; +} + + +// RLNN - RUTN +BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + return BLIS_NOT_YET_IMPLEMENTED; +} + +// LLNN - LUTN +BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + return BLIS_NOT_YET_IMPLEMENTED; +} + +// LUNN LUTN +BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB_AVX512 + ( + obj_t* AlphaObj, + obj_t* a, + obj_t* b, + cntx_t* cntx, + cntl_t* cntl + ) +{ + return BLIS_NOT_YET_IMPLEMENTED; +} + +#endif \ No newline at end of file diff --git a/kernels/zen4/bli_kernels_zen4.h b/kernels/zen4/bli_kernels_zen4.h index dd08c507a..26c2ab8c1 100644 --- a/kernels/zen4/bli_kernels_zen4.h +++ b/kernels/zen4/bli_kernels_zen4.h @@ -67,3 +67,10 @@ GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_8x32m ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_4x32m ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_2x32m ) GEMMSUP_KER_PROT( float, s, gemmsup_rv_zen_asm_1x32m ) + + +TRSMSUP_PROT(trsm_small_AVX512) +TRSMSUP_KER_PROT( d, trsm_small_AutXB_AlXB_AVX512 ) +TRSMSUP_KER_PROT( d, trsm_small_XAltB_XAuB_AVX512 ) +TRSMSUP_KER_PROT( d, trsm_small_XAutB_XAlB_AVX512 ) +TRSMSUP_KER_PROT( d, trsm_small_AltXB_AuXB_AVX512 )