Fixed ztrsm computational failure

- Fixed memory accesses for edge cases so that
  all loads stay within buffer boundaries.

- Corrected the ztrsm utility macros for dcomplex
  multiplication and division.

AMD-Internal: [CPUPL-2093]
Change-Id: Ib2c65e7921f6391b530cd20d6ea6b50f24bd705e
Author:    Harsh Dave
Date:      2022-03-30 07:16:24 -05:00
Committer: Dipal M Zambare
Parent:    0976ed9ce5
Commit:    015bcb88d4
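
For context on the first bullet: the out-of-bounds pattern this commit removes is a
full 256-bit load of a partial panel. Where only three dcomplex elements remain, the
fixed kernels pair one 256-bit load with one 128-bit load so no byte past the panel
is touched. A minimal standalone sketch of that pattern (illustrative only, not code
from this patch; compile with -mavx2):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        /* Exactly 3 dcomplex elements = 6 doubles; a pair of 256-bit
         * loads would read 8 doubles and run 16 bytes past the end. */
        double panel[6] = { 1, 2, 3, 4, 5, 6 };

        __m256d lo = _mm256_loadu_pd(panel);      /* elements 0 and 1 */
        __m128d e2 = _mm_loadu_pd(panel + 4);     /* element 2 only   */
        /* Place element 2 in the low lane; the upper lane is zeroed
         * and simply never used by the 3-element kernels. */
        __m256d hi = _mm256_insertf128_pd(_mm256_setzero_pd(), e2, 0);

        double out[8];
        _mm256_storeu_pd(out, lo);
        _mm256_storeu_pd(out + 4, hi);
        printf("%g %g %g %g %g %g\n",
               out[0], out[1], out[2], out[3], out[4], out[5]);
        return 0;
    }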


@@ -4,7 +4,7 @@
An object-based framework for developing high-performance BLAS-like
libraries.
-Copyright (C) 2018-2021, Advanced Micro Devices, Inc. All rights reserved.
+Copyright (C) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -3891,33 +3891,20 @@ err_t bli_trsm_small
*/
#define DCOMPLEX_INV(a, b) {\
-a.real = b.real;\
-a.imag = (b.imag * -1.0);\
-/*Compute denominator eliminating imaginary component*/\
-double dnm = (b.real * b.real);\
-/*multiply two times with -1 for correct result as
- * dcomplex number with positive imaginary part will
- * invert the sign if not multiplied twice with -1*/\
-dnm += ((-1.0 * (b.imag * b.imag)) * -1.0);\
-/*Compute the final result by dividing real and imag part by dnm*/\
-a.real /= dnm;\
-a.imag /= dnm;\
+/* dcomplex inva = {1.0, 0.0};*/\
+a.real = 1.0;\
+a.imag = 0.0;\
+bli_zinvscals(b, a);\
}
#define DCOMPLEX_MUL(a, b, c) {\
-double real = a.real * b.real;\
-real += ((a.imag * b.imag) * -1.0);\
-double imag = (a.real * b.imag);\
-imag += (a.imag * b.real);\
-c.real = real;\
-c.imag = imag;\
+c.real = b.real;\
+c.imag = b.imag;\
+bli_zscals(a,c);\
}
#define DCOMPLEX_DIV(a, b){\
-double dnm = b.real * b.real;\
-dnm += (-1.0 * (b.imag * (b.imag * -1.0) ));\
-a.real /= dnm;\
-a.imag /= dnm;\
+bli_zinvscals(b,a); \
}
@@ -3946,11 +3933,8 @@ err_t bli_trsm_small
#define ZTRSM_DIAG_ELE_EVAL_OPS(a,b,c){\
if(!is_unitdiag)\
{\
-a.real = b.real;\
-a.imag = (b.imag * -1.0);\
-DCOMPLEX_MUL(c, a, c)\
-DCOMPLEX_DIV(c, b)\
-}\
+bli_zinvscals(b, c);\
+}\
}
#endif
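
The rewritten macros defer to the BLIS level-0 scalar kernels rather than open-coding
the complex arithmetic. As a rough model of their semantics (a sketch only; the real
bli_zscals/bli_zinvscals are BLIS macros and may guard the division differently):

    typedef struct { double real, imag; } dcomplex;

    /* Model of bli_zscals(a, y): y := a * y (complex multiply). */
    static void zscals_model(dcomplex a, dcomplex *y)
    {
        double re = a.real * y->real - a.imag * y->imag;
        double im = a.real * y->imag + a.imag * y->real;
        y->real = re;
        y->imag = im;
    }

    /* Model of bli_zinvscals(a, y): y := y / a, computed as
     * y * conj(a) / |a|^2. */
    static void zinvscals_model(dcomplex a, dcomplex *y)
    {
        double denom = a.real * a.real + a.imag * a.imag;
        double re = (y->real * a.real + y->imag * a.imag) / denom;
        double im = (y->imag * a.real - y->real * a.imag) / denom;
        y->real = re;
        y->imag = im;
    }

Under this model, the new DCOMPLEX_INV sets a to 1 + 0i and divides it by b, and
ZTRSM_DIAG_ELE_EVAL_OPS reduces to a single c := c / b.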
@@ -4299,6 +4283,213 @@ BLIS_INLINE err_t ztrsm_AuXB_ref
_mm256_storeu_pd((double *)(b11 + cs_b * 1), ymm9);\
}
#define BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter) {\
double *tptr = (double *)b01;\
if(conjtransa) {\
ymm16 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\
for(k = 0; k< k_iter; k++) \
{ \
ymm0 = _mm256_loadu_pd((double const *)(a10)); \
xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
ymm0 = _mm256_mul_pd(ymm0, ymm16);\
ymm1 = _mm256_mul_pd(ymm1, ymm16);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \
\
ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \
\
ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\
ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \
\
ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\
ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\
\
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\
ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\
\
tptr += 2; \
a10 += p_lda; \
}\
}\
else {\
for(k = 0; k< k_iter; k++) \
{ \
ymm0 = _mm256_loadu_pd((double const *)(a10)); \
xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
ymm2 = _mm256_broadcast_sd((double const *)(tptr)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + 1)); \
\
ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
ymm11 = _mm256_fmadd_pd(ymm1, ymm2, ymm11);\
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 1 * 2 + 1)); \
\
ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\
ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b *2 * 2)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 2 + 1)); \
\
ymm10 = _mm256_fmadd_pd(ymm0, ymm2, ymm10);\
ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\
\
ymm14 = _mm256_fmadd_pd(ymm0, ymm3, ymm14);\
ymm15 = _mm256_fmadd_pd(ymm1, ymm3, ymm15);\
\
tptr += 2; \
a10 += p_lda; \
}\
}\
ymm4 = _mm256_permute_pd(ymm4, 0x5);\
ymm5 = _mm256_permute_pd(ymm5, 0x5);\
ymm6 = _mm256_permute_pd(ymm6, 0x5);\
ymm7 = _mm256_permute_pd(ymm7, 0x5);\
ymm14 = _mm256_permute_pd(ymm14, 0x5);\
ymm15 = _mm256_permute_pd(ymm15, 0x5);\
\
ymm8 = _mm256_addsub_pd(ymm8, ymm4);\
ymm11 = _mm256_addsub_pd(ymm11, ymm5);\
ymm9 = _mm256_addsub_pd(ymm9, ymm6);\
ymm12 = _mm256_addsub_pd(ymm12, ymm7);\
ymm10 = _mm256_addsub_pd(ymm10, ymm14);\
ymm13 = _mm256_addsub_pd(ymm13, ymm15);\
}
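These BLIS_ZTRSM_SMALL_GEMM_3mx* macros all share one scheme: the real part of each
B element is broadcast and accumulated via FMA into one set of registers, the
imaginary part into another, and at the end a permute swaps the real/imag slots of
the imaginary partials so a single addsub forms the complex products. A
self-contained sketch of the combine step for one register (acc_r/acc_i and the
values are illustrative, not from the source; compile with -mavx2 -mfma):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        double a_mem[4] = { 1.0, 2.0, 3.0, 4.0 }; /* a0 = 1+2i, a1 = 3+4i */
        double b_re = 5.0, b_im = 6.0;            /* b  = 5+6i            */

        __m256d a  = _mm256_loadu_pd(a_mem);
        __m256d br = _mm256_broadcast_sd(&b_re);
        __m256d bi = _mm256_broadcast_sd(&b_im);

        __m256d acc_r = _mm256_setzero_pd();      /* a * real(b) partials */
        __m256d acc_i = _mm256_setzero_pd();      /* a * imag(b) partials */

        acc_r = _mm256_fmadd_pd(a, br, acc_r);    /* (ar*br, ai*br, ...)  */
        acc_i = _mm256_fmadd_pd(a, bi, acc_i);    /* (ar*bi, ai*bi, ...)  */

        /* Swap real/imag slots of the imaginary partials,            */
        acc_i = _mm256_permute_pd(acc_i, 0x5);    /* (ai*bi, ar*bi, ...)  */

        /* then addsub: even lanes subtract, odd lanes add, giving
         * real = ar*br - ai*bi and imag = ai*br + ar*bi.             */
        __m256d c = _mm256_addsub_pd(acc_r, acc_i);

        double out[4];
        _mm256_storeu_pd(out, c);
        printf("c0 = %g%+gi, c1 = %g%+gi\n", out[0], out[1], out[2], out[3]);
        return 0;   /* prints c0 = -7+16i, c1 = -9+38i */
    }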
#define BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,p_lda,k_iter) {\
double *tptr = (double * )b01;\
if(conjtransa) {\
ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\
for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\
{\
ymm0 = _mm256_loadu_pd((double const *)(a10));\
xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
ymm0 = _mm256_mul_pd(ymm0, ymm18);\
ymm1 = _mm256_mul_pd(ymm1, ymm18);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \
\
ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \
\
ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\
ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\
tptr += 2; /*move to next row of B*/\
a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\
}\
}\
else {\
for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\
{\
ymm0 = _mm256_loadu_pd((double const *)(a10));\
xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \
\
ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1)); \
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 1 + 1)); \
\
ymm9 = _mm256_fmadd_pd(ymm0, ymm2, ymm9);\
ymm13 = _mm256_fmadd_pd(ymm1, ymm2, ymm13);\
ymm6 = _mm256_fmadd_pd(ymm0, ymm3, ymm6);\
ymm7 = _mm256_fmadd_pd(ymm1, ymm3, ymm7);\
tptr += 2; /*move to next row of B*/\
a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\
}\
}\
ymm4 = _mm256_permute_pd(ymm4, 0x5);\
ymm5 = _mm256_permute_pd(ymm5, 0x5);\
ymm6 = _mm256_permute_pd(ymm6, 0x5);\
ymm7 = _mm256_permute_pd(ymm7, 0x5);\
\
ymm8 = _mm256_addsub_pd(ymm8, ymm4);\
ymm12 = _mm256_addsub_pd(ymm12, ymm5);\
ymm9 = _mm256_addsub_pd(ymm9, ymm6);\
ymm13 = _mm256_addsub_pd(ymm13, ymm7);\
}
#define BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,p_lda,k_iter) {\
double *tptr = (double *)b01;\
if(conjtransa) {\
ymm18 = _mm256_setr_pd(1.0, -1.0, 1.0, -1.0);\
for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\
{\
ymm0 = _mm256_loadu_pd((double const *)(a10));\
xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
ymm0 = _mm256_mul_pd(ymm0, ymm18);\
ymm1 = _mm256_mul_pd(ymm1, ymm18);\
\
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \
\
ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
\
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
tptr += 2; /*move to next row of B*/\
a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\
}\
}\
else {\
for(k = 0; k< k_iter; k++) /*loop for number of GEMM operations*/\
{\
ymm0 = _mm256_loadu_pd((double const *)(a10));\
xmm4 = _mm_loadu_pd((double const *)(a10 + 2));\
ymm1 = _mm256_insertf128_pd(ymm1, xmm4, 0); \
ymm2 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0));\
ymm3 = _mm256_broadcast_sd((double const *)(tptr + cs_b * 2 * 0 + 1)); \
\
ymm8 = _mm256_fmadd_pd(ymm0, ymm2, ymm8);\
ymm12 = _mm256_fmadd_pd(ymm1, ymm2, ymm12);\
\
ymm4 = _mm256_fmadd_pd(ymm0, ymm3, ymm4);\
ymm5 = _mm256_fmadd_pd(ymm1, ymm3, ymm5);\
tptr += 2; /*move to next row of B*/\
a10 += p_lda;/*pointer math to calculate next block of A for GEMM*/\
}\
}\
ymm4 = _mm256_permute_pd(ymm4, 0x5);\
ymm5 = _mm256_permute_pd(ymm5, 0x5);\
ymm8 = _mm256_addsub_pd(ymm8, ymm4);\
ymm12 = _mm256_addsub_pd(ymm12, ymm5);\
}
/**
* Performs GEMM operation.
* Two elements of column in ymm0
@@ -31943,75 +32134,160 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB
if(m_rem == 3)
{
dim_t p_lda = 4;
-if(transa)
-{
-for(dim_t x = 0; x < i; x += p_lda)
-{
-ymm0 = _mm256_loadu_pd((double const *)(a10));
-ymm10 = _mm256_loadu_pd((double const *)
-(a10 + 2));
-ymm1 = _mm256_loadu_pd((double const *)
-(a10 + cs_a));
-ymm11 = _mm256_loadu_pd((double const *)
-(a10 + 2 + cs_a));
+if(transa)
+{
+dim_t x = 0;
+for(x = 0; (x+3) < i; x += p_lda)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+ymm10 = _mm256_loadu_pd((double const *)
+(a10 + 2));
+ymm1 = _mm256_loadu_pd((double const *)
+(a10 + cs_a));
+ymm11 = _mm256_loadu_pd((double const *)
+(a10 + 2 + cs_a));
-ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
-ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
-ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
-ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
+ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31);
-_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda), ymm7);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*2), ymm8);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*3), ymm9);
+_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2), ymm8);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*3), ymm9);
-ymm0 = _mm256_loadu_pd((double const *)(a10
-+ 2 * cs_a));
-ymm10 = _mm256_loadu_pd((double const *)(a10
-+ 2 * cs_a + 2));
+ymm0 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a));
+ymm10 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a + 2));
-ymm1 = _mm256_set_pd(1, 1, 1, 1);
+ymm1 = _mm256_loadu_pd((double const *)(a10
++ 3 * cs_a));
+ymm11 = _mm256_loadu_pd((double const *)(a10
++ 3 * cs_a + 2));
-ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
-ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
-ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20);
-ymm9 = _mm256_permute2f128_pd(ymm10,ymm1,0x31);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
+ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31);
-_mm256_storeu_pd((double *)(ptr_a10_dup + 2),
-ymm6);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda + 2), ymm7);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*2 + 2), ymm8);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*3 + 2), ymm9);
+_mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda + 2), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2 + 2), ymm8);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*3 + 2), ymm9);
-a10 += p_lda;
-ptr_a10_dup += p_lda * p_lda;
-}
+a10 += p_lda;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; (x+2) < i; x += 3)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + 2));
+ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0);
+ymm1 = _mm256_loadu_pd((double const *)
+(a10 + cs_a));
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + 2 + cs_a));
+ymm11 = _mm256_insertf128_pd(ymm11, xmm4, 0);
-}
-else
-{
-for(dim_t x=0;x<i;x++)
-{
-ymm0 = _mm256_loadu_pd((double const *)
-(a10 + rs_a * x));
-_mm256_storeu_pd((double *)
-(ptr_a10_dup + p_lda * x), ymm0);
-ymm0 = _mm256_loadu_pd((double const *)
-(a10 + rs_a * x + 2));
-_mm256_storeu_pd((double *)
-(ptr_a10_dup + p_lda * x + 2),
-ymm0);
-}
-}
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
+_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2), ymm8);
+ymm0 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a));
+xmm4 = _mm_loadu_pd((double const *)(a10
++ 2 * cs_a + 2));
+ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0);
+ymm1 = _mm256_set_pd(1, 1, 1, 1);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20);
+_mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda + 2), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2 + 2), ymm8);
+a10 += 3;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; (x+1) < i; x += 2)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+ymm1 = _mm256_loadu_pd((double const *)
+(a10 + cs_a));
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda), ymm7);
+ymm0 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a));
+ymm1 = _mm256_set_pd(1, 1, 1, 1);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+_mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda + 2), ymm7);
+a10 += 2;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; x < i; x += 1)
+{
+xmm4 = _mm_loadu_pd((double const *)(a10));
+xmm5 = _mm_loadu_pd((double const *)
+(a10 + cs_a));
+_mm_storeu_pd((double *)(ptr_a10_dup), xmm4);
+_mm_storeu_pd((double *)(ptr_a10_dup + 1), xmm5);
+xmm4 = _mm_loadu_pd((double const *)(a10
++ 2 * cs_a));
+_mm_storeu_pd((double *)(ptr_a10_dup + 2), xmm4);
+a10 += 1;
+ptr_a10_dup += p_lda * p_lda;
+}
+}
+else
+{
+for(dim_t x=0;x<i;x++)
+{
+ymm0 = _mm256_loadu_pd((double const *)
+(a10 + rs_a * x));
+_mm256_storeu_pd((double *)
+(ptr_a10_dup + p_lda * x), ymm0);
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + rs_a * x + 2));
+_mm_storeu_pd((double *)
+(ptr_a10_dup + p_lda * x + 2),
+xmm4);
+}
+}
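The repacking above replaces a single 4-wide loop, whose last trip could read past
the end of the panel when the column count is not a multiple of 4, with a 4-wide
main loop plus 3-, 2-, and 1-wide tails that use progressively narrower loads. A
sketch of just the tail structure (pack_n is a hypothetical stand-in for the
per-width bodies, which are assumed to read exactly n columns):

    #include <stdio.h>

    typedef long dim_t;

    /* Stand-in for the 4/3/2/1-column packing bodies; here it only
     * records how many columns each tail handled. */
    static void pack_n(dim_t n, dim_t *done) { *done += n; }

    int main(void)
    {
        for (dim_t i = 1; i <= 8; i++)  /* panel widths, incl. non-multiples of 4 */
        {
            dim_t x = 0, done = 0;
            for (; (x + 3) < i; x += 4) pack_n(4, &done); /* full 256-bit loads */
            for (; (x + 2) < i; x += 3) pack_n(3, &done); /* 256-bit + 128-bit  */
            for (; (x + 1) < i; x += 2) pack_n(2, &done); /* 256-bit loads only */
            for (;  x      < i; x += 1) pack_n(1, &done); /* 128-bit loads only */
            printf("i=%ld packed=%ld\n", i, done);        /* packed == i always */
        }
        return 0;
    }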
//cols
for(j = 0; (j+d_nr-1) < n; j += d_nr)
{
@@ -32023,7 +32299,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB
BLIS_SET_YMM_REG_ZEROS
///GEMM code begins///
-BLIS_ZTRSM_SMALL_GEMM_4mx3n(a10,b01,cs_b,p_lda,k_iter)
+BLIS_ZTRSM_SMALL_GEMM_3mx3n(a10,b01,cs_b,p_lda,k_iter)
///GEMM code ends///
ymm16 = _mm256_broadcast_pd((__m128d const *)
(&AlphaVal));
@@ -32119,7 +32395,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB
if(2 == n_rem)
{
///GEMM code begins///
-BLIS_ZTRSM_SMALL_GEMM_4mx2n(a10,b01,cs_b,
+BLIS_ZTRSM_SMALL_GEMM_3mx2n(a10,b01,cs_b,
p_lda,k_iter)
BLIS_PRE_ZTRSM_SMALL_3M_2N(AlphaVal,b11,cs_b)
@@ -32136,7 +32412,7 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB
else if(1 == n_rem)
{
///GEMM code begins///
-BLIS_ZTRSM_SMALL_GEMM_4mx1n(a10,b01,cs_b,
+BLIS_ZTRSM_SMALL_GEMM_3mx1n(a10,b01,cs_b,
p_lda,k_iter)
BLIS_PRE_ZTRSM_SMALL_3M_1N(AlphaVal,b11,cs_b)
@@ -32298,34 +32574,35 @@ BLIS_INLINE err_t bli_ztrsm_small_AutXB_AlXB
{
dim_t p_lda = 2; // packed leading dimension
if(transa)
-{
-for(dim_t x = 0; x < i; x += p_lda)
-{
-ymm0 = _mm256_loadu_pd((double const *)(a10));
-ymm1 = _mm256_loadu_pd((double const *)
-(a10 + cs_a));
+{
+dim_t x = 0;
+for(x = 0; (x + 1) < i; x += p_lda)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+_mm_storeu_pd((double *)(ptr_a10_dup),
+_mm256_extractf128_pd(ymm0, 0));
+_mm_storeu_pd((double *)(ptr_a10_dup +
+p_lda), _mm256_extractf128_pd(ymm0, 1));
+a10 += p_lda;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; x < i; x += 1)
+{
+xmm4 = _mm_loadu_pd((double const *)(a10));
+_mm_storeu_pd((double *)(ptr_a10_dup), xmm4);
+a10 += 1;
+ptr_a10_dup += 1;
+}
-ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
-ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
-_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda), ymm7);
-a10 += p_lda;
-ptr_a10_dup += p_lda * p_lda;
-}
-}
+}
else
{
for(dim_t x=0;x<i;x++)
{
-ymm0 = _mm256_loadu_pd((double const *)
-(a10 + rs_a * x));
-_mm256_storeu_pd((double *)
-(ptr_a10_dup + p_lda * x), ymm0);
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + rs_a *x));
+_mm_storeu_pd((double *)
+(ptr_a10_dup + p_lda *x), xmm4);
}
}
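For the p_lda == 2 panels the transpose reduces to splitting one 256-bit register
into its two 128-bit halves, one dcomplex per packed row, which is why the fixed
code stores _mm256_extractf128_pd results instead of loading a second, possibly
nonexistent column. A minimal sketch (hypothetical buffers; compile with -mavx2):

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        double col[4] = { 1, 2, 3, 4 };  /* two dcomplex: 1+2i and 3+4i */
        double packed[4];

        __m256d v = _mm256_loadu_pd(col);
        /* Element 0 goes to packed row 0, element 1 to row 1; in the
         * kernel the second store lands p_lda complex elements later. */
        _mm_storeu_pd(packed,     _mm256_extractf128_pd(v, 0));
        _mm_storeu_pd(packed + 2, _mm256_extractf128_pd(v, 1));

        printf("%g %g | %g %g\n", packed[0], packed[1], packed[2], packed[3]);
        return 0;
    }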
//cols
@@ -33083,73 +33360,158 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB
dim_t p_lda = 4;
if(transa)
{
-for(dim_t x = 0; x < m-m_remainder; x += p_lda)
-{
-ymm0 = _mm256_loadu_pd((double const *)(a10));
-ymm10 = _mm256_loadu_pd((double const *)
-(a10 + 2));
-ymm1 = _mm256_loadu_pd((double const *)
-(a10 + cs_a));
-ymm11 = _mm256_loadu_pd((double const *)
-(a10 + 2 + cs_a));
+dim_t x = 0;
+for(x = 0; (x+3) < m-m_remainder; x += p_lda)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+ymm10 = _mm256_loadu_pd((double const *)
+(a10 + 2));
+ymm1 = _mm256_loadu_pd((double const *)
+(a10 + cs_a));
+ymm11 = _mm256_loadu_pd((double const *)
+(a10 + 2 + cs_a));
-ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
-ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
-ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
-ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
+ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31);
-_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda), ymm7);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*2), ymm8);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*3), ymm9);
+_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2), ymm8);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*3), ymm9);
-ymm0 = _mm256_loadu_pd((double const *)(a10
-+ 2 * cs_a));
-ymm10 = _mm256_loadu_pd((double const *)(a10
-+ 2 * cs_a + 2));
+ymm0 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a));
+ymm10 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a + 2));
-ymm1 = _mm256_set_pd(1, 1, 1, 1);
+ymm1 = _mm256_loadu_pd((double const *)(a10
++ 3 * cs_a));
+ymm11 = _mm256_loadu_pd((double const *)(a10
++ 3 * cs_a + 2));
-ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
-ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
-ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20);
-ymm9 = _mm256_permute2f128_pd(ymm10,ymm1,0x31);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
+ymm9 = _mm256_permute2f128_pd(ymm10,ymm11,0x31);
-_mm256_storeu_pd((double *)(ptr_a10_dup + 2),
-ymm6);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda + 2), ymm7);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*2 + 2), ymm8);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda*3 + 2), ymm9);
+_mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda + 2), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2 + 2), ymm8);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*3 + 2), ymm9);
-a10 += p_lda;
-ptr_a10_dup += p_lda * p_lda;
-}
+a10 += p_lda;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; (x+2) < m-m_remainder; x += 3)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + 2));
+ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0);
+ymm1 = _mm256_loadu_pd((double const *)
+(a10 + cs_a));
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + 2 + cs_a));
+ymm11 = _mm256_insertf128_pd(ymm11, xmm4, 0);
-}
-else
-{
-for(dim_t x=0;x < m-m_remainder;x++)
-{
-ymm0 = _mm256_loadu_pd((double const *)
-(a10 + rs_a * x));
-_mm256_storeu_pd((double *)
-(ptr_a10_dup + p_lda * x), ymm0);
-ymm0 = _mm256_loadu_pd((double const *)
-(a10 + rs_a * x + 2));
-_mm256_storeu_pd((double *)
-(ptr_a10_dup + p_lda * x + 2),
-ymm0);
-}
-}
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm11,0x20);
+_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2), ymm8);
+ymm0 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a));
+xmm4 = _mm_loadu_pd((double const *)(a10
++ 2 * cs_a + 2));
+ymm10 = _mm256_insertf128_pd(ymm10, xmm4, 0);
+ymm1 = _mm256_set_pd(1, 1, 1, 1);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+ymm8 = _mm256_permute2f128_pd(ymm10,ymm1,0x20);
+_mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda + 2), ymm7);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda*2 + 2), ymm8);
+a10 += 3;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; (x+1) < m-m_remainder; x += 2)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+ymm1 = _mm256_loadu_pd((double const *)
+(a10 + cs_a));
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda), ymm7);
+ymm0 = _mm256_loadu_pd((double const *)(a10
++ 2 * cs_a));
+ymm1 = _mm256_set_pd(1, 1, 1, 1);
+ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
+ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
+_mm256_storeu_pd((double *)(ptr_a10_dup + 2), ymm6);
+_mm256_storeu_pd((double *)(ptr_a10_dup +
+p_lda + 2), ymm7);
+a10 += 2;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; x < m-m_remainder; x += 1)
+{
+xmm4 = _mm_loadu_pd((double const *)(a10));
+xmm5 = _mm_loadu_pd((double const *)
+(a10 + cs_a));
+_mm_storeu_pd((double *)(ptr_a10_dup), xmm4);
+_mm_storeu_pd((double *)(ptr_a10_dup + 1), xmm5);
+xmm4 = _mm_loadu_pd((double const *)(a10
++ 2 * cs_a));
+_mm_storeu_pd((double *)(ptr_a10_dup + 2), xmm4);
+a10 += 1;
+ptr_a10_dup += p_lda * p_lda;
+}
+}
+else
+{
+for(dim_t x=0;x < m-m_remainder;x++)
+{
+ymm0 = _mm256_loadu_pd((double const *)
+(a10 + rs_a * x));
+_mm256_storeu_pd((double *)
+(ptr_a10_dup + p_lda * x), ymm0);
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + rs_a * x + 2));
+_mm_storeu_pd((double *)
+(ptr_a10_dup + p_lda * x + 2),
+xmm4);
+}
+}
//cols
for(j = (n - d_nr); (j + 1) > 0; j -= d_nr)
{
@@ -33429,37 +33791,38 @@ BLIS_INLINE err_t bli_ztrsm_small_AltXB_AuXB
}
else if(m_remainder == 1)
{
-dim_t p_lda = 2; // packed leading dimension
-if(transa)
-{
-for(dim_t x = 0; x < m-m_remainder; x += p_lda)
-{
-ymm0 = _mm256_loadu_pd((double const *)(a10));
-ymm1 = _mm256_loadu_pd((double const *)
-(a10 + cs_a));
-ymm6 = _mm256_permute2f128_pd(ymm0,ymm1,0x20);
-ymm7 = _mm256_permute2f128_pd(ymm0,ymm1,0x31);
-_mm256_storeu_pd((double *)(ptr_a10_dup), ymm6);
-_mm256_storeu_pd((double *)(ptr_a10_dup +
-p_lda), ymm7);
-a10 += p_lda;
-ptr_a10_dup += p_lda * p_lda;
-}
-}
-else
-{
-for(dim_t x=0;x<m-m_remainder;x++)
-{
-ymm0 = _mm256_loadu_pd((double const *)
-(a10 + rs_a * x));
-_mm256_storeu_pd((double *)
-(ptr_a10_dup + p_lda * x), ymm0);
-}
-}
+dim_t p_lda = 2; // packed leading dimension
+if(transa)
+{
+dim_t x = 0;
+for(x = 0; x < m-m_remainder; x += p_lda)
+{
+ymm0 = _mm256_loadu_pd((double const *)(a10));
+_mm_storeu_pd((double *)(ptr_a10_dup),
+_mm256_extractf128_pd(ymm0, 0));
+_mm_storeu_pd((double *)(ptr_a10_dup +
+p_lda), _mm256_extractf128_pd(ymm0, 1));
+a10 += p_lda;
+ptr_a10_dup += p_lda * p_lda;
+}
+for(; x < m - m_remainder; x += 1)
+{
+xmm4 = _mm_loadu_pd((double const *)(a10));
+_mm_storeu_pd((double *)(ptr_a10_dup), xmm4);
+a10 += 1;
+ptr_a10_dup += 1;
+}
+}
+else
+{
+for(dim_t x=0;x<m-m_remainder;x++)
+{
+xmm4 = _mm_loadu_pd((double const *)
+(a10 + rs_a *x));
+_mm_storeu_pd((double *)
+(ptr_a10_dup + p_lda *x), xmm4);
+}
+}
//cols
for(j = (n - d_nr); (j + 1) > 0; j -= d_nr)
{