3m_sqp conjugate support added

1. 3m_sqp support for A matrix with conjugate_no_transpose and conjugate_transpose added.

AMD-Internal: [CPUPL-1521]
Change-Id: Ie6e5c49cf86f7d3b95d78705cf445e57f20b3d1f
This commit is contained in:
Madan mohan Manokar
2021-07-05 18:40:34 +05:30
parent 4e246b20c7
commit d3542ff0e0
4 changed files with 198 additions and 32 deletions

View File

@@ -666,8 +666,7 @@ void zgemm_
sqp_on = true;
}
#endif
if( ( ( blis_transa == BLIS_TRANSPOSE ) || ( blis_transa == BLIS_NO_TRANSPOSE ) )
&& ( blis_transb == BLIS_NO_TRANSPOSE) && (sqp_on == true))
if( ( blis_transb == BLIS_NO_TRANSPOSE) && ( sqp_on == true ) )
{
//sqp algo is found better for n > 40
if(bli_gemm_sqp(&alphao, &ao, &bo, &betao, &co, NULL, NULL)==BLIS_SUCCESS)

View File

@@ -40,16 +40,31 @@
#define BLIS_LOADFIRST 0
#define MEM_ALLOC 1//malloc performs better than bli_malloc.
/*
 * SET_TRANS(X,Y): classify the transpose/conjugate flags of object X and
 * store the result in Y as one of BLIS_NO_TRANSPOSE, BLIS_TRANSPOSE,
 * BLIS_CONJ_NO_TRANSPOSE or BLIS_CONJ_TRANSPOSE.
 *
 *   X - object queried through bli_obj_has_trans() / bli_obj_has_conj().
 *   Y - assignable lvalue that receives the classification.
 *
 * The macro now reads the object passed as X instead of a hard-coded
 * variable named `a`, so it can be applied to any operand.
 * It expands to several statements (an assignment followed by an if/else
 * chain), so it must not be used as the body of an unbraced if/else/loop.
 */
#define SET_TRANS(X,Y)\
    (Y) = BLIS_NO_TRANSPOSE;\
    if(bli_obj_has_trans( (X) ))\
    {\
        (Y) = BLIS_TRANSPOSE;\
        if(bli_obj_has_conj( (X) ))\
        {\
            (Y) = BLIS_CONJ_TRANSPOSE;\
        }\
    }\
    else if(bli_obj_has_conj( (X) ))\
    {\
        (Y) = BLIS_CONJ_NO_TRANSPOSE;\
    }
//Macro for 3m_sqp n loop
#define BLI_SQP_ZGEMM_N(MX)\
int j=0;\
for(; j<=(n-nx); j+= nx)\
{\
status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, isTransA, MX, p_istart, kx, &mem_3m_sqp);\
status = bli_sqp_zgemm_m8( m, nx, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\
}\
if(j<n)\
{\
status = bli_sqp_zgemm_m8( m, n-j, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, isTransA, MX, p_istart, kx, &mem_3m_sqp);\
status = bli_sqp_zgemm_m8( m, n-j, k, a, lda, b+(j*ldb), ldb, c+(j*ldc), ldc, alpha_real, beta_real, transa, MX, p_istart, kx, &mem_3m_sqp);\
}
//Macro for sqp_dgemm n loop
@@ -135,7 +150,7 @@ BLIS_INLINE err_t bli_sqp_zgemm( gint_t m,
guint_t ldc,
double alpha,
double beta,
bool isTransA,
trans_t transa,
dim_t nt);
BLIS_INLINE err_t bli_sqp_dgemm( gint_t m,
@@ -181,7 +196,7 @@ err_t bli_gemm_sqp
return BLIS_INVALID_ROW_STRIDE;
}
if(bli_obj_has_conj(a) || bli_obj_has_conj(b))
if(bli_obj_has_conj(b))
{
return BLIS_NOT_YET_IMPLEMENTED;
}
@@ -216,6 +231,9 @@ err_t bli_gemm_sqp
isTransA = true;
}
trans_t transa = BLIS_NO_TRANSPOSE;
SET_TRANS(a,transa)
dim_t nt = bli_thread_get_num_threads(); // get number of threads
double* ap = ( double* )bli_obj_buffer( a );
@@ -237,7 +255,7 @@ err_t bli_gemm_sqp
return BLIS_NOT_YET_IMPLEMENTED;
}
//printf("zsqp ");
return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, isTransA, nt);
return bli_sqp_zgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, alpha_real, beta_real, transa, nt);
}
else if(dt == BLIS_DOUBLE)
{
@@ -254,6 +272,8 @@ err_t bli_gemm_sqp
return BLIS_NOT_YET_IMPLEMENTED;
}
//printf("dsqp ");
// In the dgemm case only transpose and no-transpose are handled;
// conjugate-transpose and conjugate-no-transpose are not applicable.
return bli_sqp_dgemm( m, n, k, ap, lda, bp, ldb, cp, ldc, *alpha_cast, *beta_cast, isTransA, nt);
}
@@ -627,7 +647,7 @@ BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m,
guint_t ldb,
double* c,
guint_t ldc,
bool isTransA,
trans_t transa,
double alpha,
double beta,
gint_t mx,
@@ -653,7 +673,7 @@ BLIS_INLINE void bli_sqp_zgemm_kx( gint_t m,
double* pas = as;
/* a matrix real and imag packing and compute. */
bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, isTransA, mx, p);
bli_3m_sqp_packA_real_imag_sum(a, i, kx+p, lda, par, pai, pas, transa, mx, p);
double* pcr = cr;
double* pci = ci;
@@ -874,7 +894,7 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m,
guint_t ldc,
double alpha,
double beta,
bool isTransA,
trans_t transa,
gint_t mx,
gint_t* p_istart,
gint_t kx,
@@ -926,14 +946,14 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m,
for(; p <= (k-kx); p += kx)
{
bli_sqp_zgemm_kx(m, n, kx, p, a, lda, k, c, ldc,
isTransA, alpha, beta, mx, i, ar, ai, as,
transa, alpha, beta, mx, i, ar, ai, as,
br + p, bi + p, bs + p, cr, ci, w, a_aligned);
}// k loop end
if(p<k)
{
bli_sqp_zgemm_kx(m, n, (k - p), p, a, lda, k, c, ldc,
isTransA, alpha, beta, mx, i, ar, ai, as,
transa, alpha, beta, mx, i, ar, ai, as,
br + p, bi + p, bs + p, cr, ci, w, a_aligned);
}
#else//kloop
@@ -946,7 +966,7 @@ BLIS_INLINE err_t bli_sqp_zgemm_m8( gint_t m,
double* pas = as;
/* a matrix real and imag packing and compute. */
bli_3m_sqp_packA_real_imag_sum(a, i, k, lda, par, pai, pas, isTransA, mx, 0);
bli_3m_sqp_packA_real_imag_sum(a, i, k, lda, par, pai, pas, transa, mx, 0);
double* pcr = cr;
double* pci = ci;
@@ -1035,7 +1055,7 @@ BLIS_INLINE err_t bli_sqp_zgemm(gint_t m,
guint_t ldc,
double alpha_real,
double beta_real,
bool isTransA,
trans_t transa,
dim_t nt)
{
gint_t istart = 0;
@@ -1078,7 +1098,7 @@ BLIS_INLINE err_t bli_sqp_zgemm(gint_t m,
kx = k;
}
// for tn case there is a bug in handling k parts. To be fixed.
if(isTransA==true)
if(transa!=BLIS_NO_TRANSPOSE)
{
kx = k;
}

View File

@@ -1412,19 +1412,18 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
double *par,
double *pai,
double *pas,
bool isTransA,
trans_t transa,
gint_t mx,
gint_t p)
{
__m256d av0, av1, av2, av3;
__m256d tv0, tv1, sum;
__m256d tv0, tv1, sum, zerov;
gint_t poffset = p;
#if KLP
//k = p + k;
#endif
if(mx==8)
{
if(isTransA==false)
if(transa == BLIS_NO_TRANSPOSE)
{
pa = pa +i;
#if KLP
@@ -1432,7 +1431,6 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
#else
p = 0;
#endif
//printf("packA from p_%d to p_%d \n", p, k);
for (; p < k; p += 1)
{
//for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each.
@@ -1462,12 +1460,51 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
pa = pa + lda;
}
}
else
else if(transa == BLIS_CONJ_NO_TRANSPOSE)
{
zerov = _mm256_setzero_pd();
pa = pa +i;
#if KLP
pa = pa + (p*lda);
#else
p = 0;
#endif
for (; p < k; p += 1)
{
//for (int ii = 0; ii < MX8 * 2; ii += 2) //real + imag : Rkernel needs 8 elements each.
av0 = _mm256_loadu_pd(pa);
av1 = _mm256_loadu_pd(pa+4);
av2 = _mm256_loadu_pd(pa+8);
av3 = _mm256_loadu_pd(pa+12);
tv0 = _mm256_permute2f128_pd(av0, av1, 0x20);
tv1 = _mm256_permute2f128_pd(av0, av1, 0x31);
av0 = _mm256_unpacklo_pd(tv0, tv1);
av1 = _mm256_unpackhi_pd(tv0, tv1);
av1 = _mm256_sub_pd(zerov,av1);//negate imaginary component
sum = _mm256_add_pd(av0, av1);
_mm256_storeu_pd(par, av0); par += 4;
_mm256_storeu_pd(pai, av1); pai += 4;
_mm256_storeu_pd(pas, sum); pas += 4;
tv0 = _mm256_permute2f128_pd(av2, av3, 0x20);
tv1 = _mm256_permute2f128_pd(av2, av3, 0x31);
av2 = _mm256_unpacklo_pd(tv0, tv1);
av3 = _mm256_unpackhi_pd(tv0, tv1);
av3 = _mm256_sub_pd(zerov,av3);//negate imaginary component
sum = _mm256_add_pd(av2, av3);
_mm256_storeu_pd(par, av2); par += 4;
_mm256_storeu_pd(pai, av3); pai += 4;
_mm256_storeu_pd(pas, sum); pas += 4;
pa = pa + lda;
}
}
else if(transa == BLIS_TRANSPOSE)
{
gint_t idx = (i/2) * lda;
pa = pa + idx;
#if KLP
//pa = pa + p;
#else
p = 0;
#endif
@@ -1527,18 +1564,79 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
}
}
}
} //mx==8
else//mx==1
{
if(isTransA==false)
else if(transa == BLIS_CONJ_TRANSPOSE)
{
pa = pa + i;
gint_t idx = (i/2) * lda;
pa = pa + idx;
#if KLP
#else
p = 0;
#endif
//A conjugate Transpose case:
for (gint_t ii = 0; ii < BLIS_MX8 ; ii++)
{
gint_t idx = ii * lda;
gint_t sidx;
gint_t pidx = 0;
gint_t max_k = (k*2) - 8;
for (p = poffset; p <= max_k; p += 8)
{
double ar0_ = *(pa + idx + p);
double ai0_ = -(*(pa + idx + p + 1));
double ar1_ = *(pa + idx + p + 2);
double ai1_ = -(*(pa + idx + p + 3));
double ar2_ = *(pa + idx + p + 4);
double ai2_ = -(*(pa + idx + p + 5));
double ar3_ = *(pa + idx + p + 6);
double ai3_ = -(*(pa + idx + p + 7));
sidx = (pidx/2) * BLIS_MX8;
*(par + sidx + ii) = ar0_;
*(pai + sidx + ii) = ai0_;
*(pas + sidx + ii) = ar0_ + ai0_;
sidx = ((pidx+2)/2) * BLIS_MX8;
*(par + sidx + ii) = ar1_;
*(pai + sidx + ii) = ai1_;
*(pas + sidx + ii) = ar1_ + ai1_;
sidx = ((pidx+4)/2) * BLIS_MX8;
*(par + sidx + ii) = ar2_;
*(pai + sidx + ii) = ai2_;
*(pas + sidx + ii) = ar2_ + ai2_;
sidx = ((pidx+6)/2) * BLIS_MX8;
*(par + sidx + ii) = ar3_;
*(pai + sidx + ii) = ai3_;
*(pas + sidx + ii) = ar3_ + ai3_;
pidx += 8;
}
for (; p < (k*2); p += 2)
{
double ar_ = *(pa + idx + p);
double ai_ = -(*(pa + idx + p + 1));
gint_t sidx = (pidx/2) * BLIS_MX8;
*(par + sidx + ii) = ar_;
*(pai + sidx + ii) = ai_;
*(pas + sidx + ii) = ar_ + ai_;
pidx += 2;
}
}
}
} //mx==8
else//mx==1
{
if(transa == BLIS_NO_TRANSPOSE)
{
pa = pa + i;
#if KLP
//pa = pa + (p*lda); done below.. not needed
#else
p = 0;
#endif
//printf(" packAx1 from p_%d to p_%d ",p,k-1);
//A No transpose case:
for (; p < k; p += 1)
{
@@ -1554,12 +1652,33 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
}
}
}
else
else if(transa == BLIS_CONJ_NO_TRANSPOSE)
{
pa = pa + i;
#if KLP
#else
p = 0;
#endif
//A conjugate No transpose case:
for (; p < k; p += 1)
{
gint_t idx = p * lda;
for (gint_t ii = 0; ii < (mx*2) ; ii += 2)
{ //real + imag : Rkernel needs 8 elements each.
double ar_ = *(pa + idx + ii);
double ai_ = -(*(pa + idx + ii + 1));// conjugate: negate imaginary component
*par = ar_;
*pai = ai_;
*pas = ar_ + ai_;
par++; pai++; pas++;
}
}
}
else if(transa == BLIS_TRANSPOSE)
{
gint_t idx = (i/2) * lda;
pa = pa + idx;
#if KLP
//pa = pa + p; done below.. not needed
#else
p = 0;
#endif
@@ -1583,6 +1702,34 @@ void bli_3m_sqp_packA_real_imag_sum(double *pa,
}
}
}
else if(transa == BLIS_CONJ_TRANSPOSE)
{
gint_t idx = (i/2) * lda;
pa = pa + idx;
#if KLP
#else
p = 0;
#endif
//A conjugate Transpose case:
for (gint_t ii = 0; ii < mx ; ii++)
{
gint_t idx = ii * lda;
gint_t sidx;
gint_t pidx = 0;
for (p = poffset;p < (k*2); p += 2)
{
double ar0_ = *(pa + idx + p);
double ai0_ = -(*(pa + idx + p + 1));
sidx = (pidx/2) * mx;
*(par + sidx + ii) = ar0_;
*(pai + sidx + ii) = ai0_;
*(pas + sidx + ii) = ar0_ + ai0_;
pidx += 2;
}
}
}
}//mx==1
}

View File

@@ -62,4 +62,4 @@ void bli_prepackA_1(double* pa, double* aPacked, gint_t k, guint_t lda, bool isT
/* Pack real and imaginary parts in separate buffers and also multiply with multiplication factor */
void bli_3m_sqp_packC_real_imag(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double mul, gint_t mx);
void bli_3m_sqp_packB_real_imag_sum(double* pb, guint_t n, guint_t k, guint_t ldb, double* pbr, double* pbi, double* pbs, double mul, gint_t mx);
void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, bool isTransA, gint_t mx, gint_t p);
void bli_3m_sqp_packA_real_imag_sum(double *pa, gint_t i, guint_t k, guint_t lda, double *par, double *pai, double *pas, trans_t transa, gint_t mx, gint_t p);