diff --git a/kernels/zen/3/bli_trsm_small.c b/kernels/zen/3/bli_trsm_small.c index 5b6df35d7..168fe48d7 100644 --- a/kernels/zen/3/bli_trsm_small.c +++ b/kernels/zen/3/bli_trsm_small.c @@ -1682,6 +1682,99 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); +#define BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2));\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_pd((double const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_pd(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + +#define BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_sd((double const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11));\ + ymm3 = _mm256_fmsub_pd(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_pd(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_pd(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_pd(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_pd(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_sd ((double const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_pd(ymm0, ymm15, ymm13); + #ifdef BLIS_DISABLE_TRSM_PREINVERSION #define STRSM_SMALL_DIV_OR_SCALE _mm256_div_ps #endif @@ -1936,6 +2029,22 @@ BLIS_INLINE err_t dtrsm_XAltB_ref b10 += cs_b;\ } +#define BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) \ + for(k = 0; k < k_iter; k++) /*loop for number of GEMM operations*/\ + {\ + /*load 5x1 block of B10*/\ + ymm0 = _mm256_broadcast_ss((float const *)(b10 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b10));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + /*broadcast 1st row of A01*/\ + ymm2 = _mm256_broadcast_ss((float const *)(a01 + p_lda * 0)); /*A01[0][0]*/\ + ymm3 = _mm256_fmadd_ps(ymm2, ymm0, ymm3);\ +\ + a01 += 1; /*move to next row*/\ + b10 += cs_b;\ + } + /*GEMM block used in strsm small left cases*/ #define BLIS_STRSM_SMALL_GEMM_16mx6n(a10,b01,cs_b,p_lda,k_iter) \ float *b01_prefetch = b01 + 8; \ @@ -3228,6 +3337,280 @@ BLIS_INLINE err_t dtrsm_XAltB_ref ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5));\ ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); +#define BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*2 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*3 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*4 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + cs_b*5 + 6));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + \ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + 4 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 1);\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4));\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*2));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*3));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*4));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + 4 + cs_b*5));\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_loadu_ps((float const *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*2));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*3));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*4));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)(b11 + 2 + cs_b*5));\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*3));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*4));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + xmm5 = _mm_broadcast_ss((float *)&zero);\ + xmm5 = _mm_loadl_pi(xmm5,(__m64 *)(b11 + cs_b*5));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm5, 0);\ +\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + +#define BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b)\ + ymm15 = _mm256_broadcast_ss((float const *)&AlphaVal); /*register to hold alpha*/\ +\ + ymm0 = _mm256_broadcast_ss((float const *)b11);\ + ymm3 = _mm256_fmsub_ps(ymm0, ymm15, ymm3);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b));\ + ymm5 = _mm256_fmsub_ps(ymm0, ymm15, ymm5);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*2));\ + ymm7 = _mm256_fmsub_ps(ymm0, ymm15, ymm7);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*3));\ + ymm9 = _mm256_fmsub_ps(ymm0, ymm15, ymm9);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*4));\ + ymm11 = _mm256_fmsub_ps(ymm0, ymm15, ymm11);\ +\ + ymm0 = _mm256_broadcast_ss((float const *)(b11 + cs_b*5));\ + ymm13 = _mm256_fmsub_ps(ymm0, ymm15, ymm13); + /* Load b11 of size 6x8 and multiply with alpha Add the GEMM output and perform inregister transose of b11 @@ -6628,7 +7011,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6752,7 +7135,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -6869,7 +7252,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -8276,7 +8659,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x = 0;(x + d_nr - 1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -8315,6 +8699,34 @@ BLIS_INLINE err_t bli_dtrsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -8979,7 +9391,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9094,7 +9506,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -9202,7 +9614,7 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB BLIS_DTRSM_SMALL_GEMM_6nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_DTRSM_SMALL_6x4(AlphaVal,b11,cs_b) + BLIS_PRE_DTRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -10591,7 +11003,8 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB if(transa) { - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x =0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { ymm0 = _mm256_loadu_pd((double const *)(a01)); ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); @@ -10638,6 +11051,34 @@ BLIS_INLINE err_t bli_dtrsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_loop_count = p_lda - x; + if(remainder_loop_count >= 4) + { + ymm0 = _mm256_loadu_pd((double const *)(a01)); + ymm1 = _mm256_loadu_pd((double const *)(a01 + cs_a)); + ymm2 = _mm256_loadu_pd((double const *)(a01 + cs_a * 2)); + ymm3 = _mm256_loadu_pd((double const *)(a01 + cs_a * 3)); + + ymm4 = _mm256_unpacklo_pd(ymm0, ymm1); + ymm5 = _mm256_unpacklo_pd(ymm2, ymm3); + + ymm6 = _mm256_permute2f128_pd(ymm4,ymm5,0x20); + ymm8 = _mm256_permute2f128_pd(ymm4,ymm5,0x31); + + ymm0 = _mm256_unpackhi_pd(ymm0, ymm1); + ymm1 = _mm256_unpackhi_pd(ymm2, ymm3); + + ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x20); + ymm9 = _mm256_permute2f128_pd(ymm0,ymm1,0x31); + + _mm256_storeu_pd((double *)(ptr_a10_dup), ymm6); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda), ymm7); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*2), ymm8); + _mm256_storeu_pd((double *)(ptr_a10_dup + p_lda*3), ymm9); + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_loop_count = remainder_loop_count - 4; + } } else { @@ -11696,7 +12137,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AltXB_AuXB dim_t p_lda = 4; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_pd((double const *)(a10)); ymm1 = _mm256_loadu_pd((double const *)(a10 + cs_a)); @@ -14530,7 +14971,9 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11 + cs_b * 3), ymm3); //store(B11[0-3][3]) ymm0 = _mm256_loadu_pd((double const *)(b11 + cs_b *4)); - ymm1 = _mm256_loadu_pd((double const *)(b11 + cs_b *5)); + xmm5 = _mm_loadu_pd((double const*)(b11 + cs_b * 5)); + ymm1 = _mm256_insertf128_pd(ymm1, xmm5, 0); + ymm8 = _mm256_fmsub_pd(ymm0, ymm16, ymm4); ymm9 = _mm256_fmsub_pd(ymm1, ymm16, ymm5); @@ -14538,7 +14981,8 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB ymm1 = _mm256_blend_pd(ymm9, ymm1, 0x0C); _mm256_storeu_pd((double *)(b11 + cs_b * 4), ymm0); //store(B11[0-3][2]) - _mm256_storeu_pd((double *)(b11 + cs_b * 5), ymm1); //store(B11[0-3][3]) + xmm5 = _mm256_castpd256_pd128(ymm1); + _mm_storeu_pd((double *)(b11 + cs_b * 5), xmm5); if(transa) dtrsm_AutXB_ref(a11, b11, m_rem, 6, cs_a, cs_b, is_unitdiag); @@ -14585,7 +15029,7 @@ BLIS_INLINE err_t bli_dtrsm_small_AutXB_AlXB _mm256_storeu_pd((double *)(b11), ymm0); //store(B11[0-3][0]) _mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm1); //store(B11[0-3][1]) _mm256_storeu_pd((double *)(b11 + cs_b * 2), ymm2); //store(B11[0-3][2]) - xmm5 = _mm256_castpd256_pd128(ymm3); + xmm5 = _mm256_castpd256_pd128(ymm3); _mm_storeu_pd((double *)(b11 + cs_b * 3), xmm5); if(transa) @@ -15820,7 +16264,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -15903,25 +16347,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -=7; } @@ -15941,7 +16389,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16024,25 +16472,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=6; } @@ -16062,7 +16503,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16145,25 +16586,18 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -=5; } @@ -16183,7 +16617,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16266,25 +16700,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -=4; } @@ -16304,7 +16725,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16387,25 +16808,29 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -=3; } @@ -16425,7 +16850,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16508,25 +16933,23 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -=2; } @@ -16546,7 +16969,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -16629,25 +17052,12 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = STRSM_SMALL_DIV_OR_SCALE(ymm3, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -=1; } @@ -18690,7 +19100,8 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB __m128 xmm0, xmm1, xmm2, xmm3; __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -18733,6 +19144,33 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + remainder_count = remainder_count - 4; + } } else { @@ -18909,7 +19347,7 @@ BLIS_INLINE err_t bli_strsm_small_XAutB_XAlB ymm3 = _mm256_setzero_ps(); ///GEMM implementation starts/// - BLIS_STRSM_SMALL_GEMM_1nx8m(a01,b10,cs_b,p_lda,k_iter) + BLIS_STRSM_SMALL_GEMM_1nx5m(a01,b10,cs_b,p_lda,k_iter) BLIS_PRE_STRSM_SMALL_1N_5M(AlphaVal,b11,cs_b) @@ -19510,7 +19948,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x7(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19601,25 +20039,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x7F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x7F); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_store_ss((float *)(b11 + 6),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 1)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 1)); + + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); + _mm_store_ss((float *)(b11 + 6 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 1)); m_remainder -= 7; i += 7; @@ -19640,7 +20082,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x6(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19731,25 +20173,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x3F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x3F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storel_pi((__m64 *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_storel_pi((__m64 *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 6; i += 6; @@ -19770,7 +20205,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x5(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19861,25 +20296,18 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x1F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x1F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((float *)(b11 + 4),_mm256_extractf128_ps(ymm3, 1)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b),_mm256_extractf128_ps(ymm5, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*2),_mm256_extractf128_ps(ymm7, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*3),_mm256_extractf128_ps(ymm9, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*4),_mm256_extractf128_ps(ymm11, 1)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); + _mm_store_ss((float *)(b11 + 4 + cs_b*5),_mm256_extractf128_ps(ymm13, 1)); m_remainder -= 5; i += 5; @@ -19900,7 +20328,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x4(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -19991,25 +20419,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x0F); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x0F); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_storeu_ps((float *)(b11),_mm256_extractf128_ps(ymm3, 0)); + _mm_storeu_ps((float *)(b11 + cs_b),_mm256_extractf128_ps(ymm5, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*2),_mm256_extractf128_ps(ymm7, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*3),_mm256_extractf128_ps(ymm9, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*4),_mm256_extractf128_ps(ymm11, 0)); + _mm_storeu_ps((float *)(b11 + cs_b*5),_mm256_extractf128_ps(ymm13, 0)); m_remainder -= 4; i += 4; @@ -20030,7 +20445,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 8x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x3(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20121,25 +20536,29 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x07); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x07); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); + _mm_store_ss((float *)(b11+2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm3,ymm3), 0)); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + _mm_store_ss((float *)(b11+ 2 + cs_b),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm5,ymm5), 0)); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*2),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm7,ymm7), 0)); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*3),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm9,ymm9), 0)); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*4),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm11,ymm11), 0)); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); + _mm_store_ss((float *)(b11 + 2 + cs_b*5),_mm256_extractf128_ps(_mm256_unpackhi_ps(ymm13,ymm13), 0)); m_remainder -= 3; i += 3; @@ -20160,7 +20579,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x2(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20251,25 +20670,23 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x03); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x03); + xmm5 = _mm256_extractf128_ps(ymm3, 0); + _mm_storel_pi((__m64 *)(b11),xmm5); - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + xmm5 = _mm256_extractf128_ps(ymm5, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm7, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*2),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm9, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*3),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm11, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*4),xmm5); + + xmm5 = _mm256_extractf128_ps(ymm13, 0); + _mm_storel_pi((__m64 *)(b11 + cs_b*5),xmm5); m_remainder -= 2; i += 2; @@ -20290,7 +20707,7 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB BLIS_STRSM_SMALL_GEMM_6nx8m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_STRSM_SMALL_6x8(AlphaVal,b11,cs_b) + BLIS_PRE_STRSM_SMALL_6x1(AlphaVal,b11,cs_b) ///implement TRSM/// @@ -20381,25 +20798,12 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB ymm13 = STRSM_SMALL_DIV_OR_SCALE(ymm13, ymm0); - ymm0 = _mm256_loadu_ps((float const *)b11); - ymm3 = _mm256_blend_ps(ymm0, ymm3, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm5 = _mm256_blend_ps(ymm0, ymm5, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*2)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm7 = _mm256_blend_ps(ymm0, ymm7, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*3)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm9 = _mm256_blend_ps(ymm0, ymm9, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*4)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm11 = _mm256_blend_ps(ymm0, ymm11, 0x01); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b*5)); //B11[0][1] B11[1][1] B11[2][1] B11[3][1] - ymm13 = _mm256_blend_ps(ymm0, ymm13, 0x01); - - _mm256_storeu_ps((float *)b11, ymm3); - _mm256_storeu_ps((float *)(b11 + cs_b), ymm5); - _mm256_storeu_ps((float *)(b11 + cs_b*2), ymm7); - _mm256_storeu_ps((float *)(b11 + cs_b*3), ymm9); - _mm256_storeu_ps((float *)(b11 + cs_b*4), ymm11); - _mm256_storeu_ps((float *)(b11 + cs_b*5), ymm13); + _mm_store_ss((b11 + cs_b * 0), _mm256_extractf128_ps(ymm3, 0)); + _mm_store_ss((b11 + cs_b * 1), _mm256_extractf128_ps(ymm5, 0)); + _mm_store_ss((b11 + cs_b * 2), _mm256_extractf128_ps(ymm7, 0)); + _mm_store_ss((b11 + cs_b * 3), _mm256_extractf128_ps(ymm9, 0)); + _mm_store_ss((b11 + cs_b * 4), _mm256_extractf128_ps(ymm11, 0)); + _mm_store_ss((b11 + cs_b * 5), _mm256_extractf128_ps(ymm13, 0)); m_remainder -= 1; i += 1; @@ -22523,7 +22927,8 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB __m128 xmm4, xmm5, xmm6, xmm7; __m128 xmm8, xmm9; - for(dim_t x =0;x < p_lda;x+=d_nr) + dim_t x = 0; + for(x =0;(x+d_nr-1) < p_lda;x+=d_nr) { xmm0 = _mm_loadu_ps((float const *)(a01)); xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); @@ -22566,6 +22971,32 @@ BLIS_INLINE err_t bli_strsm_small_XAltB_XAuB a01 += d_nr*cs_a; ptr_a10_dup += d_nr; } + dim_t remainder_count = p_lda - x; + if(remainder_count >= 4) + { + xmm0 = _mm_loadu_ps((float const *)(a01)); + xmm1 = _mm_loadu_ps((float const *)(a01 + cs_a)); + xmm2 = _mm_loadu_ps((float const *)(a01 + cs_a * 2)); + xmm3 = _mm_loadu_ps((float const *)(a01 + cs_a * 3)); + + xmm4 = _mm_unpacklo_ps(xmm0, xmm1); + xmm5 = _mm_unpacklo_ps(xmm2, xmm3); + xmm6 = _mm_shuffle_ps(xmm4,xmm5,0x44); + xmm7 = _mm_shuffle_ps(xmm4,xmm5,0xEE); + + xmm0 = _mm_unpackhi_ps(xmm0, xmm1); + xmm1 = _mm_unpackhi_ps(xmm2, xmm3); + xmm8 = _mm_shuffle_ps(xmm0,xmm1,0x44); + xmm9 = _mm_shuffle_ps(xmm0,xmm1,0xEE); + + _mm_storeu_ps((float *)(ptr_a10_dup), xmm6); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda), xmm7); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*2), xmm8); + _mm_storeu_ps((float *)(ptr_a10_dup + p_lda*3), xmm9); + + a01 += 4*cs_a; + ptr_a10_dup += 4; + } } else { @@ -29414,7 +29845,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB dim_t p_lda = 8; // packed leading dimension if(transa) { - for(dim_t x =0;x < m-i+8;x+=p_lda) + for(dim_t x =0;x < m-i-8;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -30332,7 +30763,7 @@ BLIS_INLINE err_t bli_strsm_small_AltXB_AuXB __m128 xmm6,xmm7,xmm8,xmm9; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { xmm0 = _mm_loadu_ps((float const *)(a10)); xmm1 = _mm_loadu_ps((float const *)(a10 + cs_a)); @@ -36293,10 +36724,10 @@ BLIS_INLINE err_t bli_ztrsm_small_XAltB_XAuB BLIS_SET_YMM_REG_ZEROS ///GEMM implementation starts/// - BLIS_ZTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) + BLIS_ZTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_ZTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_ZTRSM_SMALL_2x4(AlphaVal,b11,cs_b) ///implement TRSM/// ////extract a00 @@ -37265,7 +37696,8 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37367,8 +37799,10 @@ BLIS_INLINE void ctrsm_small_pack_diag_element ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ ymm16 = _mm256_permute_ps(ymm16, 0x44);\ \ - ymm0 = _mm256_loadu_ps((float const *)(b11));\ - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1));\ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + xmm1 = _mm_loadu_ps((float const *)(b11 + cs_b * 1));\ + ymm1 = _mm256_insertf128_ps(ymm1, xmm1, 0);\ /*in register transpose * ymm0,ymm1,ymm2 holds * two dcomplex elements of b11 cols*/\ @@ -37513,6 +37947,132 @@ BLIS_INLINE void ctrsm_small_pack_diag_element }\ } +/** + * Multiplies Alpha with one scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with two scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + +/** + * Multiplies Alpha with three scomplex + * element of three column. + */ +#define BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal, b11,cs_b){\ + ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal));\ + ymm16 = _mm256_permute_ps(ymm16, 0x44);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm8 = _mm256_sub_ps(ymm19, ymm8);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm10 = _mm256_sub_ps(ymm19, ymm10);\ + \ + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2));\ + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b*2 + 2));\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0);\ + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1);\ + \ + ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11);\ + ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);\ + ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5);\ + ymm19 = _mm256_mul_ps(ymm19, ymm17);\ + ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19);\ + ymm12 = _mm256_sub_ps(ymm19, ymm12);\ + \ +} + /** * Multiplies Alpha with four scomplex * element of three column. @@ -40496,8 +41056,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB (float const *)(a10 + cs_a)); ymm2 = _mm256_loadu_ps( (float const *)(a10 + cs_a * 2)); - ymm3 = _mm256_loadu_ps( - (float const *)(a10 + cs_a * 3)); + ymm3 = _mm256_broadcast_ss((float const *)&ones); ymm4 = _mm256_shuffle_ps(ymm0, ymm1, 0x44); ymm5 = _mm256_shuffle_ps(ymm2, ymm3, 0x44); @@ -40709,10 +41268,12 @@ BLIS_INLINE err_t bli_ctrsm_small_AutXB_AlXB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); - ymm1 = _mm256_loadu_ps((float const *)(b11 + cs_b *1)); - ymm2 = _mm256_loadu_ps((float const *)(b11 + cs_b *2)); - + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm1 = _mm256_insertf128_ps(ymm1, xmm0, 0); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b*2)); + ymm2 = _mm256_insertf128_ps(ymm2, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); @@ -42321,7 +42882,7 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB dim_t p_lda = 4; if(transa) { - for(dim_t x =0;x < m-i+4;x+=p_lda) + for(dim_t x =0;x < m-i-4;x+=p_lda) { ymm0 = _mm256_loadu_ps((float const *)(a10)); ymm1 = _mm256_loadu_ps((float const *)(a10 + cs_a)); @@ -42360,11 +42921,11 @@ BLIS_INLINE err_t bli_ctrsm_small_AltXB_AuXB { if(transa) { - ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,cs_a,d11_pack,4); } else { - ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,m_rem); + ctrsm_small_pack_diag_element(is_unitdiag,a11,rs_a,d11_pack,4); } } @@ -43556,7 +44117,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -43664,7 +44225,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -43763,7 +44324,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack + 2)); @@ -44111,7 +44672,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44120,7 +44684,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44184,12 +44751,13 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// + ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_2nx4m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44198,7 +44766,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44261,7 +44830,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44270,7 +44840,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -44486,12 +45057,15 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB k_iter = (n-n_rem); BLIS_SET_S_YMM_REG_ZEROS - ///GEMM implementation starts/// + ///GEMM implementation starts/// BLIS_CTRSM_SMALL_GEMM_1nx4m(a01,b10,cs_b,p_lda,k_iter) ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44530,7 +45104,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44567,7 +45142,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAutB_XAlB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -44994,7 +45570,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x3(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45116,7 +45692,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x2(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45232,7 +45808,7 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB BLIS_CTRSM_SMALL_GEMM_3nx4m(a01,b10,cs_b,p_lda,k_iter) // Load b11 of size 4x6 and multiply with alpha - BLIS_PRE_CTRSM_SMALL_3x4(AlphaVal,b11,cs_b) + BLIS_PRE_CTRSM_SMALL_3x1(AlphaVal,b11,cs_b) ymm18 = _mm256_setr_ps(1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0); ymm1 = _mm256_broadcast_ps(( __m128 const *)(d11_pack)); @@ -45598,7 +46174,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45607,7 +46186,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45678,7 +46260,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45687,7 +46270,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm0 = _mm_loadu_ps((float const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45753,7 +46337,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -45762,7 +46347,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm19 = _mm256_fmaddsub_ps(ymm18, ymm16, ymm19); ymm8 = _mm256_sub_ps(ymm19, ymm8); - ymm0 = _mm256_loadu_ps((float const *)(b11 + cs_b * 1)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + cs_b)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); ymm19 = _mm256_shuffle_ps(ymm0, ymm0,0xF5); @@ -45984,7 +46570,10 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11 + 2)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 1); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46026,7 +46615,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm0 = _mm_loadu_ps((float const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm0, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0); @@ -46066,7 +46656,8 @@ BLIS_INLINE err_t bli_ctrsm_small_XAltB_XAuB ymm16 = _mm256_broadcast_ps(( __m128 const *)(&AlphaVal)); ymm16 = _mm256_permute_ps(ymm16, 0x44); - ymm0 = _mm256_loadu_ps((float const *)(b11)); + xmm1 = _mm_loadl_pi(xmm1, (__m64 const *)(b11)); + ymm0 = _mm256_insertf128_ps(ymm0, xmm1, 0); ymm17 = _mm256_shuffle_ps(ymm16, ymm16, 0x11); ymm18 = _mm256_shuffle_ps(ymm0, ymm0, 0xA0);