From 88e44c64e330d77af8f33d541537e23857ea8fb6 Mon Sep 17 00:00:00 2001 From: satish kumar nuggu Date: Thu, 11 Aug 2022 14:44:16 +0530 Subject: [PATCH] Fixed Memory Leaks in TRSM 1. Fixed the memory leaks in corner cases which caused due to extra loads in all datatypes(s,d,c,z). 2. In remainder cases instead of loading required number of elements, loaded extra elements which lead to memory leaks. Fixed memory leaks by restricting number of loads to required number of elements. AMD-Internal: [CPUPL-2280] Change-Id: Ia49a02565e01d5ed05e98090b7773a444587cd8a --- kernels/zen/3/bli_trsm_small.c | 1253 +++++++++++++++++++++++--------- 1 file changed, 922 insertions(+), 331 deletions(-) diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 5b6df35d7..168fe48d7 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -1682,6 +1682,99 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); +#define BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2));\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11));\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define STRSM_SMALL_DIV_OR_SCALE _mm256_div_ps #endif @@ -1936,6 +2029,22 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + /*GEMM block used in strsm small left cases*/ #define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ float *b01_prefetch = b01 + 8; \ @@ -3228,6 +3337,280 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); +#define BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*2 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*3 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*4 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*5 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*3));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*4));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*5));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + /* Load b11 of size 6x8 and multiply with alpha Add the GEMM output and perform inregister transose of b11 @@ -6628,7 +7011,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6752,7 +7135,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6869,7 +7252,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -8276,7 +8659,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x = 0;(x + d_nr - 1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -8315,6 +8699,34 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -8979,7 +9391,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9094,7 +9506,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9202,7 +9614,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -10591,7 +11003,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x =0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -10638,6 +11051,34 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -11696,7 +12137,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB dim_t p_lda = 4; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_pd((double const *)(a10)); ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); @@ -14530,7 +14971,9 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); @@ -14538,7 +14981,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + xmm5 = _mm256_castpd256_pd128(ymm1); + _mm_storeu_pd((double *)(b11 + cs_b * 5), xmm5); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); @@ -14585,7 +15029,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_castpd256_pd128(ymm3); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); if(transa) @@ -15820,7 +16264,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15903,25 +16347,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -=7; } @@ -15941,7 +16389,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16024,25 +16472,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=6; } @@ -16062,7 +16503,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16145,25 +16586,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=5; } @@ -16183,7 +16617,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16266,25 +16700,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -=4; } @@ -16304,7 +16725,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16387,25 +16808,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -=3; } @@ -16425,7 +16850,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16508,25 +16933,23 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -=2; } @@ -16546,7 +16969,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16629,25 +17052,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -=1; } @@ -18690,7 +19100,8 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB __m128 xmm0, xmm1, xmm2, xmm3; __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -18733,6 +19144,33 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_count = remainder_count - 4; + } } else { @@ -18909,7 +19347,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) @@ -19510,7 +19948,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19601,25 +20039,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -= 7; i += 7; @@ -19640,7 +20082,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19731,25 +20173,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 6; i += 6; @@ -19770,7 +20205,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19861,25 +20296,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 5; i += 5; @@ -19900,7 +20328,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19991,25 +20419,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -= 4; i += 4; @@ -20030,7 +20445,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20121,25 +20536,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -= 3; i += 3; @@ -20160,7 +20579,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20251,25 +20670,23 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -= 2; i += 2; @@ -20290,7 +20707,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20381,25 +20798,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -= 1; i += 1; @@ -22523,7 +22927,8 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -22566,6 +22971,32 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + } } else { @@ -29414,7 +29845,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB dim_t p_lda = 8; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+8;x+=p_lda) + for(dim_t x =0;x < m-i-8;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -30332,7 +30763,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB __m128 xmm6,xmm7,xmm8,xmm9; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { xmm0 = _mm_loadu_ps((float const *)(a10)); xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); @@ -36293,10 +36724,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -37265,7 +37696,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37367,8 +37799,10 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + xmm1 = _mm_loadu_ps((float const *)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm1, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37513,6 +37947,132 @@ BLIS_INLINE void ctrsm_small_pack_diag_element }\ } +/** + * Multiplies Alpha with one scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with two scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with three scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + /** * Multiplies Alpha with four scomplex * element of three column. @@ -40496,8 +41056,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB (float const *)(a10 + cs_a)); ymm2 = _mm256_loadu_ps( (float const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_ps( - (float const *)(a10 + cs_a * 3)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); @@ -40709,10 +41268,12 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -42321,7 +42882,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB dim_t p_lda = 4; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -42360,11 +42921,11 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,4); } else { - ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,4); } } @@ -43556,7 +44117,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -43664,7 +44225,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -43763,7 +44324,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -44111,7 +44672,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44120,7 +44684,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44184,12 +44751,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// + ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44198,7 +44766,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44261,7 +44830,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44270,7 +44840,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44486,12 +45057,15 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// + ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44530,7 +45104,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44567,7 +45142,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44994,7 +45570,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45116,7 +45692,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45232,7 +45808,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45598,7 +46174,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45607,7 +46186,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45678,7 +46260,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45687,7 +46270,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45753,7 +46337,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45762,7 +46347,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45984,7 +46570,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46026,7 +46615,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46066,7 +46656,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);